422 lines
12 KiB
422 lines
12 KiB
import logging
from binascii import hexlify, unhexlify
from typing import List

from asn1crypto.keys import PublicKeyInfo
from coincurve import PublicKey as cPublicKey
from google.protobuf.json_format import MessageToDict
from google.protobuf.message import DecodeError
from hachoir.core.log import log as hachoir_log
from hachoir.metadata import extractMetadata as binary_file_metadata
from hachoir.parser import createParser as binary_file_parser

from hub.error import InputValueIsNoneError
from hub.schema import compat
from hub.schema.attrs import (
    Source, Playable, Dimmensional, Fee, Image, Video, Audio,
    LanguageList, LocationList, ClaimList, ClaimReference, TagList,
)
from hub.schema.base import Signable
from hub.schema.mime_types import guess_media_type, guess_stream_type
from hub.schema.types.v2.claim_pb2 import Claim as ClaimMessage
# Module-level logger for this file (used e.g. when file metadata cannot be read).
log = logging.getLogger(__name__)

# Turn off hachoir's own printing of log output.
# NOTE(review): presumably this stops hachoir writing to stdout — confirm in hachoir.core.log.
hachoir_log.use_print = False
class Claim(Signable):
    """Signable wrapper around the protobuf ``Claim`` message.

    The message's ``type`` oneof holds exactly one of the stream, channel,
    collection or repost sub-messages; the properties below expose typed
    views (``Stream``, ``Channel``, ...) over that member.
    """

    STREAM = 'stream'
    CHANNEL = 'channel'
    COLLECTION = 'collection'
    REPOST = 'repost'

    __slots__ = ()

    message_class = ClaimMessage

    @property
    def claim_type(self) -> str:
        """Name of the populated ``type`` oneof member, or None if none is set."""
        return self.message.WhichOneof('type')

    def get_message(self, type_name):
        """Return the ``type_name`` sub-message, selecting it if no type is set yet.

        :param type_name: one of the ``STREAM``/``CHANNEL``/``COLLECTION``/``REPOST`` names.
        :raises ValueError: if the claim already holds a different type.
        """
        message = getattr(self.message, type_name)
        if self.claim_type is None:
            # Merely reading a sub-message does not mark it present. SetInParent()
            # does, which makes it the active member of the ``type`` oneof.
            # Without this, a brand-new Claim() could never be typed.
            message.SetInParent()
        if self.claim_type != type_name:
            raise ValueError(f'Claim is not a {type_name}.')
        return message

    @property
    def is_stream(self):
        """True if this claim's type is ``stream``."""
        return self.claim_type == self.STREAM

    @property
    def stream(self) -> 'Stream':
        """``Stream`` view over this claim (raises ValueError if it is another type)."""
        return Stream(self)

    @property
    def is_channel(self):
        """True if this claim's type is ``channel``."""
        return self.claim_type == self.CHANNEL

    @property
    def channel(self) -> 'Channel':
        """``Channel`` view over this claim (raises ValueError if it is another type)."""
        return Channel(self)

    @property
    def is_repost(self):
        """True if this claim's type is ``repost``."""
        return self.claim_type == self.REPOST

    @property
    def repost(self) -> 'Repost':
        """``Repost`` view over this claim (raises ValueError if it is another type)."""
        return Repost(self)

    @property
    def is_collection(self):
        """True if this claim's type is ``collection``."""
        return self.claim_type == self.COLLECTION

    @property
    def collection(self) -> 'Collection':
        """``Collection`` view over this claim (raises ValueError if it is another type)."""
        return Collection(self)

    @classmethod
    def from_bytes(cls, data: bytes) -> 'Claim':
        """Decode ``data`` into a Claim, falling back to legacy encodings.

        Payloads that fail to decode in the current format are handled as follows:

        * first byte ``{`` is parsed as the old JSON schema (version 0);
        * any other first byte except 0 and 1 is parsed as the v1 protobuf
          schema (version 1).

        :raises DecodeError: if ``data`` is empty, or starts with byte 0 or 1
            and still failed to decode. Returning an empty claim there would
            hide corrupt data.
        """
        try:
            return super().from_bytes(data)
        except DecodeError:
            # NOTE(review): bytes 0/1 are presumably the current Signable
            # format markers — confirm in hub.schema.base.
            if not data or data[0] in (0, 1):
                raise
            claim = cls()
            if data[0] == ord('{'):
                claim.version = 0
                compat.from_old_json_schema(claim, data)
            else:
                claim.version = 1
                compat.from_types_v1(claim, data)
            return claim
class BaseClaim:
    """Common view over a ``Claim`` for one specific claim type.

    Subclasses set ``claim_type`` to the oneof member they wrap, and may extend:

    * ``object_fields`` — sub-object properties that ``update`` fills from
      ``<field>_<attr>=value`` keyword arguments (e.g. ``thumbnail_url``);
    * ``repeat_fields`` — list-like properties that ``update`` appends to or
      extends, and empties when ``clear_<field>=True`` is passed.
    """

    __slots__ = 'claim', 'message'

    claim_type = None
    object_fields = 'thumbnail',
    repeat_fields = 'tags', 'languages', 'locations'

    def __init__(self, claim: Claim = None):
        """Wrap ``claim`` (or a new empty ``Claim``) and select this view's type.

        :raises ValueError: if ``claim`` already holds a different type.
        """
        self.claim = claim if claim is not None else Claim()
        self.message = self.claim.get_message(self.claim_type)

    def to_dict(self):
        """Return the claim as a plain dict, with the type sub-message flattened in.

        ``MessageToDict`` nests the oneof member under its own key (e.g.
        ``claim['stream']``). Its fields are hoisted to the top level, next
        to the shared title/description/tags fields, which is where the
        subclass ``to_dict`` overrides look for them. Languages are rendered
        as tag strings and locations via their ``to_dict``.
        """
        claim = MessageToDict(self.claim.message, preserving_proto_field_name=True)
        claim.update(claim.pop(self.claim_type, {}))
        if 'languages' in claim:
            claim['languages'] = self.langtags
        if 'locations' in claim:
            claim['locations'] = [location.to_dict() for location in self.locations]
        return claim

    def none_check(self, kwargs):
        """Raise ``InputValueIsNoneError`` naming the first key whose value is None."""
        for key, value in kwargs.items():
            if value is None:
                raise InputValueIsNoneError(key)

    def update(self, **kwargs):
        """Apply keyword updates to this claim.

        * ``<object_field>_<attr>=value`` sets ``attr`` on that sub-object.
        * ``clear_<repeat_field>=True`` empties the list first; then
          ``<repeat_field>=str`` appends one item and ``=list`` extends.
        * Every remaining keyword is assigned with ``setattr`` on this view.

        :raises ValueError: if a repeat field is given something other than
            a ``str`` or ``list``.
        """
        for key in list(kwargs):
            for field in self.object_fields:
                prefix = f'{field}_'
                if key.startswith(prefix):
                    setattr(getattr(self, field), key[len(prefix):], kwargs.pop(key))
                    # The key is consumed; matching another field would pop it twice.
                    break

        for name in self.repeat_fields:
            field = getattr(self, name)
            if kwargs.pop(f'clear_{name}', False):
                del field[:]
            items = kwargs.pop(name, None)
            if items is None:
                continue
            if isinstance(items, str):
                field.append(items)
            elif isinstance(items, list):
                field.extend(items)
            else:
                raise ValueError(f"Unknown {name} value: {items}")

        for key, value in kwargs.items():
            setattr(self, key, value)

    @property
    def title(self) -> str:
        """Claim title (shared by all claim types)."""
        return self.claim.message.title

    @title.setter
    def title(self, title: str):
        self.claim.message.title = title

    @property
    def description(self) -> str:
        """Claim description (shared by all claim types)."""
        return self.claim.message.description

    @description.setter
    def description(self, description: str):
        self.claim.message.description = description

    @property
    def thumbnail(self) -> Source:
        """``Source`` wrapper over the claim's thumbnail sub-message."""
        return Source(self.claim.message.thumbnail)

    @property
    def tags(self) -> TagList:
        """List-like wrapper over the claim's tags."""
        return TagList(self.claim.message.tags)

    @property
    def languages(self) -> LanguageList:
        """List-like wrapper over the claim's languages."""
        return LanguageList(self.claim.message.languages)

    @property
    def langtags(self) -> List[str]:
        """The ``langtag`` of each entry in ``languages``."""
        return [language.langtag for language in self.languages]

    @property
    def locations(self) -> LocationList:
        """List-like wrapper over the claim's locations."""
        return LocationList(self.claim.message.locations)
class Stream(BaseClaim):
    """View over a stream claim: its source file, fee, licensing and media info."""

    __slots__ = ()

    claim_type = Claim.STREAM

    object_fields = BaseClaim.object_fields + ('source',)

    def to_dict(self):
        """Return the flattened claim dict with stream-specific fix-ups.

        Source hashes and the fee address are replaced with the values the
        ``Source``/``Fee`` wrappers expose; ``MessageToDict`` would otherwise
        emit bytes fields base64-encoded. The fee amount is stringified, and
        ``stream_type`` is added when the source has a media type.
        """
        claim = super().to_dict()
        if 'source' in claim:
            source = claim['source']
            if 'hash' in source:
                source['hash'] = self.source.file_hash
            if 'sd_hash' in source:
                source['sd_hash'] = self.source.sd_hash
            elif 'bt_infohash' in source:
                source['bt_infohash'] = self.source.bt_infohash
            if 'media_type' in source:
                claim['stream_type'] = guess_stream_type(source['media_type'])
        fee = claim.get('fee', {})
        if 'address' in fee:
            fee['address'] = self.fee.address
        if 'amount' in fee:
            fee['amount'] = str(self.fee.amount)
        return claim

    def update(self, file_path=None, height=None, width=None, duration=None, **kwargs):
        """Update the stream from keyword arguments.

        Consumes the fee keywords (``clear_fee``, ``fee_address``,
        ``fee_currency``, ``fee_amount``) and the source keywords (``sd_hash``
        or ``bt_infohash``, ``file_name``, ``file_hash``, ``file_size``). It
        then refreshes the image/video/audio sub-message and forwards the
        remaining keywords to ``BaseClaim.update``.

        :param file_path: local file describing the stream. It updates the
            source, and media metadata is read from it when ``duration``,
            ``width`` and ``height`` are not all provided.
        :param height: media height, passed to dimensional media (image/video).
        :param width: media width, passed to dimensional media (image/video).
        :param duration: media duration, passed to playable media (video/audio).
        """
        # Always pop the fee keywords so they never fall through to
        # BaseClaim.update (which would try to setattr them on this view).
        fee_args = (
            kwargs.pop('fee_address', None),
            kwargs.pop('fee_currency', None),
            kwargs.pop('fee_amount', None),
        )
        if kwargs.pop('clear_fee', False):
            self.message.ClearField('fee')
            # Like clear_<field> for repeat fields: clear first, then apply new values.
            if any(arg is not None for arg in fee_args):
                self.fee.update(*fee_args)
        else:
            self.fee.update(*fee_args)

        if 'sd_hash' in kwargs:
            self.source.sd_hash = kwargs.pop('sd_hash')
        elif 'bt_infohash' in kwargs:
            self.source.bt_infohash = kwargs.pop('bt_infohash')
        if 'file_name' in kwargs:
            self.source.name = kwargs.pop('file_name')
        if 'file_hash' in kwargs:
            self.source.file_hash = kwargs.pop('file_hash')

        stream_type = None
        if file_path is not None:
            stream_type = self.source.update(file_path=file_path)
        elif self.source.name:
            self.source.media_type, stream_type = guess_media_type(self.source.name)
        elif self.source.media_type:
            stream_type = guess_stream_type(self.source.media_type)

        if 'file_size' in kwargs:
            self.source.size = kwargs.pop('file_size')

        # Drop a media sub-message that no longer matches the source's type.
        if self.stream_type is not None and self.stream_type != stream_type:
            self.message.ClearField(self.stream_type)

        if stream_type in ('image', 'video', 'audio'):
            media = getattr(self, stream_type)
            media_args = {'file_metadata': None}
            if file_path is not None and not all((duration, width, height)):
                media_args['file_metadata'] = self._read_file_metadata(file_path)
            if isinstance(media, Playable):
                media_args['duration'] = duration
            if isinstance(media, Dimmensional):
                media_args['height'] = height
                media_args['width'] = width
            media.update(**media_args)

        super().update(**kwargs)

    @staticmethod
    def _read_file_metadata(file_path):
        """Best-effort read of hachoir metadata for ``file_path``; None on failure."""
        try:
            parser = binary_file_parser(file_path)
            if parser is None:
                log.warning('Could not read file metadata: no parser for %s', file_path)
                return None
            # The parser keeps the file open; the context manager closes it.
            with parser:
                return binary_file_metadata(parser)
        except Exception:
            log.exception('Could not read file metadata.')
            return None

    @property
    def author(self) -> str:
        """Stream author."""
        return self.message.author

    @author.setter
    def author(self, author: str):
        self.message.author = author

    @property
    def license(self) -> str:
        """License name."""
        return self.message.license

    @license.setter
    def license(self, license: str):
        self.message.license = license

    @property
    def license_url(self) -> str:
        """License URL."""
        return self.message.license_url

    @license_url.setter
    def license_url(self, license_url: str):
        self.message.license_url = license_url

    @property
    def release_time(self) -> int:
        """Release time as stored in the message."""
        return self.message.release_time

    @release_time.setter
    def release_time(self, release_time: int):
        self.message.release_time = release_time

    @property
    def fee(self) -> Fee:
        """``Fee`` wrapper over the stream's fee sub-message."""
        return Fee(self.message.fee)

    @property
    def has_fee(self) -> bool:
        """True if the fee sub-message is present."""
        return self.message.HasField('fee')

    @property
    def has_source(self) -> bool:
        """True if the source sub-message is present."""
        return self.message.HasField('source')

    @property
    def source(self) -> Source:
        """``Source`` wrapper over the stream's source sub-message."""
        return Source(self.message.source)

    @property
    def stream_type(self) -> str:
        """Name of the populated media ``type`` oneof member, or None."""
        return self.message.WhichOneof('type')

    @property
    def image(self) -> Image:
        """``Image`` wrapper over the image sub-message."""
        return Image(self.message.image)

    @property
    def video(self) -> Video:
        """``Video`` wrapper over the video sub-message."""
        return Video(self.message.video)

    @property
    def audio(self) -> Audio:
        """``Audio`` wrapper over the audio sub-message."""
        return Audio(self.message.audio)
class Channel(BaseClaim):
    """View over a channel claim: its public key, contact info, cover and featured claims."""

    __slots__ = ()

    claim_type = Claim.CHANNEL

    object_fields = BaseClaim.object_fields + ('cover',)
    repeat_fields = BaseClaim.repeat_fields + ('featured',)

    def to_dict(self):
        """Return the flattened claim dict with the hex public key and featured claim ids."""
        claim = super().to_dict()
        claim['public_key'] = self.public_key
        if 'featured' in claim:
            claim['featured'] = self.featured.ids
        return claim

    @property
    def public_key(self) -> str:
        """Hex string of ``public_key_bytes``."""
        return hexlify(self.public_key_bytes).decode()

    @public_key.setter
    def public_key(self, sd_public_key: str):
        self.message.public_key = unhexlify(sd_public_key.encode())

    @property
    def public_key_bytes(self) -> bytes:
        """The channel public key in 33-byte compressed form.

        A 33-byte stored key is returned as-is. Anything else is parsed as an
        ASN.1 ``PublicKeyInfo`` structure and re-encoded compressed via
        coincurve.
        """
        if len(self.message.public_key) == 33:
            return self.message.public_key
        public_key_info = PublicKeyInfo.load(self.message.public_key)
        public_key = cPublicKey(public_key_info.native['public_key'])
        return public_key.format(compressed=True)

    @public_key_bytes.setter
    def public_key_bytes(self, public_key: bytes):
        self.message.public_key = public_key

    @property
    def email(self) -> str:
        """Channel contact e-mail."""
        return self.message.email

    @email.setter
    def email(self, email: str):
        self.message.email = email

    @property
    def website_url(self) -> str:
        """Channel website URL."""
        return self.message.website_url

    @website_url.setter
    def website_url(self, website_url: str):
        self.message.website_url = website_url

    @property
    def cover(self) -> Source:
        """``Source`` wrapper over the channel's cover sub-message."""
        return Source(self.message.cover)

    @property
    def featured(self) -> ClaimList:
        """``ClaimList`` wrapper over the channel's featured claims."""
        return ClaimList(self.message.featured)
class Repost(BaseClaim):
    """View over a repost claim."""

    __slots__ = ()

    claim_type = Claim.REPOST

    @property
    def reference(self) -> ClaimReference:
        """``ClaimReference`` wrapper over the repost sub-message."""
        return ClaimReference(self.message)
class Collection(BaseClaim):
    """View over a collection claim: an ordered list of claim references."""

    __slots__ = ()

    claim_type = Claim.COLLECTION

    repeat_fields = BaseClaim.repeat_fields + ('claims',)

    def to_dict(self):
        """Return the flattened claim dict with ``claim_references`` replaced by ``claims`` ids."""
        claim = super().to_dict()
        if claim.pop('claim_references', None):
            claim['claims'] = self.claims.ids
        return claim

    @property
    def claims(self) -> ClaimList:
        """``ClaimList`` wrapper over the collection sub-message."""
        return ClaimList(self.message)