Source code for hop.models

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
import email
import fastavro
from io import BytesIO
import json
import re
from typing import Any, Dict, List, Union
import collections.abc

import xmltodict

from . import plugins
from . import avro_utils

# Type alias for values that can appear in JSON: scalars, null, objects with
# string keys, and arrays. Nested values are typed loosely as Any.
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]


def format_name(cls):
    """Return the format label used for a model class.

    Args:
        cls: The class to label.

    Returns:
        The class's ``format_name`` attribute when it has one, otherwise the
        lower-cased class name.

    """
    try:
        return cls.format_name
    except AttributeError:
        # no explicit label on the class: derive one from its name
        return cls.__name__.lower()


@dataclass
class MessageModel(ABC):
    """An abstract message model.

    Subclasses are dataclasses that must implement load(). The default
    encoding of a message is the JSON form of its dataclass fields,
    encoded as UTF-8.

    """

    def __bytes__(self):
        """Produce the canonical representation of this message.

        Returns:
            The dataclass fields as a JSON document encoded in UTF-8.

        """
        return json.dumps(asdict(self)).encode("utf-8")

    def serialize(self):
        """Wrap the message with its format and content, for transmission over Kafka.

        Returns:
            A dictionary with "format" and "content" keys.
            The value stored under "format" is the format label.
            The value stored under "content" is the actual encoded data.

        """
        # by default, encode using JSON
        return {
            "format": format_name(type(self)),
            "content": bytes(self),
        }

    @classmethod
    def deserialize(cls, data):
        """Unwrap a message produced by serialize() (the "content" value).

        Args:
            data: The encoded content, as UTF-8 JSON bytes.

        Returns:
            An instance of the model class.

        """
        # corresponding to the default serialize implementation, unpack from JSON
        return cls(**json.loads(data.decode("utf-8")))

    @classmethod
    def load_file(cls, filename):
        """Create a new message model from a file.

        Args:
            filename: The path to a file.

        Returns:
            The message model.

        """
        # Decode explicitly as UTF-8, matching every other encode/decode in
        # this module, rather than relying on the platform default encoding.
        with open(filename, "r", encoding="utf-8") as f:
            return cls.load(f)

    @classmethod
    @abstractmethod
    def load(cls, input_):
        """Create a new message model from a file object or string.

        This base implementation has no functionality and should not be called.

        Args:
            input_: A file object or string.

        Returns:
            The message model.

        Raises:
            NotImplementedError: Always, in this base implementation.

        """
        raise NotImplementedError("MessageModel.load() should not be called")
@dataclass
class VOEvent(MessageModel):
    """Defines a VOEvent 2.0 structure.

    Implements the schema defined by:
        http://www.ivoa.net/Documents/VOEvent/20110711/

    """

    ivorn: str
    role: str = "observation"
    version: str = "2.0"

    Who: dict = field(default_factory=dict)
    What: dict = field(default_factory=dict)
    WhereWhen: dict = field(default_factory=dict)
    How: dict = field(default_factory=dict)
    Why: dict = field(default_factory=dict)
    Citations: dict = field(default_factory=dict)
    Description: dict = field(default_factory=dict)
    Reference: dict = field(default_factory=dict)

    def __str__(self):
        """Render the event fields as JSON indented by two spaces."""
        as_mapping = asdict(self)
        return json.dumps(as_mapping, indent=2)

    def __bytes__(self):
        """Return the indented JSON rendering encoded as UTF-8."""
        # TODO: this isn't really suitable, as the output should be the original XML format
        # That requires some massaging of the data to restore information discarded by xmltodict.
        text = str(self)
        return text.encode("utf-8")

    def serialize(self):
        """Wrap the message with its format and content.

        Returns:
            A dictionary with "format" and "content" keys.
            The value stored under "format" is the format label.
            The value stored under "content" is the actual encoded data.

        """
        # by default, encode using JSON
        label = format_name(type(self))
        encoded = json.dumps(asdict(self)).encode("utf-8")
        return {"format": label, "content": encoded}

    @classmethod
    def load(cls, xml_input):
        """Create a new VOEvent from an XML-formatted VOEvent.

        Args:
            xml_input: A file object, string, or generator.

        Returns:
            The VOEvent.

        """
        parsed = xmltodict.parse(xml_input, attr_prefix="")

        # enter root and remove XML-specific namespaces
        root = parsed["voe:VOEvent"]
        kwargs = {}
        for key, value in root.items():
            if ":" in key:
                continue
            kwargs[key] = value
        return cls(**kwargs)

    @classmethod
    def load_file(cls, filename):
        """Create a new VOEvent from an XML-formatted VOEvent file.

        Args:
            filename: Name of the VOEvent file.

        Returns:
            The VOEvent.

        """
        with open(filename, "rb") as xml_file:
            return cls.load(xml_file)
@dataclass
class GCNTextNotice(MessageModel):
    """A GCN Notice in the plain text format.

    The original message data is stored in the raw property, and the parsed
    message is available via the fields property, which is a dictionary.

    """

    raw: bytes
    fields: dict

    def __bytes__(self):
        """Return the original, unparsed message data."""
        return self.raw

    @classmethod
    def deserialize(cls, data):
        """Parse the "key: value" lines of a text notice.

        Keys are stripped and lower-cased. Lines that do not look like
        "key: value" and repeated keys are both appended to the previous
        value, separated by a newline.

        Args:
            data: The notice text as UTF-8 encoded bytes.

        Returns:
            The GCNTextNotice, with data kept unchanged in raw.

        """
        fields = {}
        line_pattern = re.compile("([^:]*): *(.*)")
        last_key = None
        for line in data.decode("utf-8").strip().splitlines():
            m = line_pattern.match(line)
            if m is None:
                # if the line does not match the "key: value" format,
                # we assume it is a continuation of a previous value
                # (lines before the first key have nothing to extend and are dropped)
                if last_key is not None:
                    fields[last_key] += '\n' + line.lstrip()
                continue
            last_key = m.group(1).strip().lower()
            value = m.group(2).strip()
            if last_key in fields:
                # if the key is repeated,
                # we treat the new value as a continuation of the previous one
                fields[last_key] += '\n' + value
            else:
                fields[last_key] = value
        return cls(raw=data, fields=fields)

    @classmethod
    def load(cls, input):
        """Create a new GCNTextNotice from a file object, bytes, or string.

        Args:
            input: The notice data, or a file object to read it from.

        Returns:
            The GCNTextNotice.

        """
        raw = input.read() if hasattr(input, "read") else input
        if isinstance(raw, str):
            # deserialize() decodes bytes, so text (a str argument or a
            # text-mode file) is encoded first instead of failing on .decode
            raw = raw.encode("utf-8")
        return cls.deserialize(raw)

    @classmethod
    def load_file(cls, filename):
        """Create a new GCNTextNotice from a file.

        Args:
            filename: The path to a file.

        Returns:
            The GCNTextNotice.

        """
        with open(filename, "rb") as f:
            return cls.load(f)
@dataclass
class GCNCircular(MessageModel):
    """Defines a GCN Circular structure.

    The parsed GCN circular is formatted as a dictionary with
    the following schema:

        {'header': {'title': ..., 'number': ..., ...}, 'body': ...}

    """

    header: dict
    body: str
    format_name = "circular"

    def __str__(self):
        """Render the circular as header lines, a blank line, then the body.

        Each header line is the upper-cased name followed by a colon, padded
        on the right to at least nine characters, followed by the value.

        """
        headers = [(name.upper() + ":").ljust(9) + val for name, val in self.header.items()]
        return "\n".join(headers + ["", self.body])

    def __bytes__(self):
        """Return the text rendering encoded as UTF-8."""
        return str(self).encode("utf-8")

    def serialize(self):
        """Wrap the message with its format and content.

        Returns:
            A dictionary with "format" and "content" keys.
            The value stored under "format" is the format label.
            The value stored under "content" is the actual encoded data.

        """
        # by default, encode using JSON
        return {
            "format": format_name(type(self)),
            "content": json.dumps(asdict(self)).encode("utf-8"),
        }

    @classmethod
    def load(cls, email_input):
        """Create a new GCNCircular from an RFC 822 formatted circular.

        Args:
            email_input: A file object, string, or bytes.

        Returns:
            The GCNCircular.

        """
        raw = email_input.read() if hasattr(email_input, "read") else email_input
        if isinstance(raw, (bytes, bytearray)):
            # binary input (bytes or a file opened in binary mode)
            message = email.message_from_bytes(raw)
        else:
            message = email.message_from_string(raw)

        # format gcn circular into header/body
        return cls(
            header={title.lower(): content for title, content in message.items()},
            body=message.get_payload(),
        )
@dataclass
class Blob(MessageModel):
    """Defines an opaque message blob.

    The content is carried as-is: it is neither parsed nor re-encoded.

    """

    content: bytes

    def __str__(self):
        """Return str() of the stored content."""
        return str(self.content)

    def __bytes__(self):
        """Return the stored content unchanged."""
        return self.content

    @classmethod
    def deserialize(cls, data):
        """Wrap already-encoded data in a Blob without interpreting it.

        Args:
            data: The encoded content.

        Returns:
            The Blob.

        """
        return cls(content=data)

    @classmethod
    def load(cls, blob_input):
        """Create a blob message from input data.

        Args:
            blob_input: The unstructured message data (bytes) or file object.

        Returns:
            The Blob.

        """
        is_file_like = hasattr(blob_input, "read")
        payload = blob_input.read() if is_file_like else blob_input
        return cls(content=payload)

    @classmethod
    def load_file(cls, filename):
        """Create a new message model from a file.

        Args:
            filename: The path to a file.

        Returns:
            The message model.

        """
        # arbitrary data should not be subjected to line-ending conversion, etc.
        with open(filename, "rb") as binary_file:
            return cls.load(binary_file)
@dataclass
class JSONBlob(MessageModel):
    """Defines an unformatted message blob.

    The content is any JSON-compatible value; it is encoded as UTF-8 JSON.

    """

    content: JSONType
    format_name = "json"

    def __str__(self):
        """Return str() of the stored content."""
        return str(self.content)

    def __bytes__(self):
        """Return the content as a JSON document encoded in UTF-8."""
        document = json.dumps(self.content)
        return document.encode("utf-8")

    @classmethod
    def deserialize(cls, data):
        """Decode UTF-8 JSON bytes into a JSONBlob.

        Args:
            data: The encoded content.

        Returns:
            The Blob.

        """
        text = data.decode("utf-8")
        return cls(content=json.loads(text))

    @classmethod
    def load(cls, blob_input):
        """Create a blob message from input text.

        Args:
            blob_input: The unstructured message text or file object.

        Returns:
            The Blob.

        """
        text = blob_input.read() if hasattr(blob_input, "read") else blob_input
        if isinstance(text, bytes):
            # binary input is taken to be UTF-8 encoded JSON
            text = text.decode("utf-8")
        return cls(content=json.loads(text))
@dataclass
class AvroBlob(MessageModel):
    """Defines an unformatted message blob.

    The content is either one record or, when single_record is False, a
    sequence of records. It is encoded with fastavro using the given
    schema, or one generated from the records when no schema is set.

    """

    # serializing as Avro supports essentially the same types as JSON
    content: Union[JSONType, List[JSONType]]
    schema: dict = None
    format_name = "avro"
    single_record: bool = True

    def __init__(self, content: Union[JSONType, List[JSONType]], schema: dict = None,
                 single_record: bool = True):
        """Create an AvroBlob.

        Args:
            content: A single record, or a sequence of records when
                single_record is False.
            schema: The Avro schema to write with, or None to have one
                generated when the message is encoded.
            single_record: True if content is exactly one record.

        Raises:
            TypeError: If single_record is False and content is not a sequence.

        """
        if not single_record and not isinstance(content, collections.abc.Sequence):
            raise TypeError("AvroBlob requires content to be a sequence of records when "
                            "single_record = False")
        self.content = content
        self.schema = schema
        self.single_record = single_record

    def __str__(self):
        """Return str() of the stored content."""
        return str(self.content)

    def __bytes__(self):
        """Encode the record(s) as Avro data.

        If no schema has been set, one is generated from the records and
        stored on this instance before writing.

        """
        if self.single_record:
            records = [self.content]
        else:
            records = self.content
        if self.schema is None:
            # make up an ad-hoc schema
            self.schema = avro_utils.SchemaGenerator().find_common_type(records)
        buffer = BytesIO()
        fastavro.writer(buffer, self.schema, records)
        return buffer.getvalue()

    @classmethod
    def _read_avro(cls, stream, single_record=True):
        """Read every record from a binary Avro stream and build an AvroBlob.

        Args:
            stream: A binary file object positioned at the Avro data.
            single_record: True if the stream must hold exactly one record.

        Returns:
            The AvroBlob, carrying the schema the data was written with.

        Raises:
            TypeError: If single_record is True and the stream does not
                contain exactly one record.

        """
        reader = fastavro.reader(stream)
        extracted = list(reader)
        if single_record:
            # an empty stream is rejected here too, with the record count in
            # the message, instead of failing on extracted[0] with IndexError
            if len(extracted) != 1:
                raise TypeError("AvroBlob requires content to be a single "
                                "record when single_record = True, number of "
                                "records in message being deserialized is "
                                f"{len(extracted)}.")
            content = extracted[0]
        else:
            content = extracted
        return cls(content=content, schema=reader.writer_schema, single_record=single_record)

    @classmethod
    def deserialize(cls, data, single_record=True):
        """Decode Avro bytes into an AvroBlob.

        Args:
            data: The encoded Avro data.
            single_record: True if the data only contains one record.
                True by default.

        Returns:
            The Blob.

        """
        return cls._read_avro(BytesIO(data), single_record=single_record)

    @classmethod
    def load(cls, blob_input, single_record=True):
        """Create a blob message from input avro data.

        Args:
            blob_input: The encoded Avro data or file object.
            single_record: True if input avro data only contains one record.
                True by default.

        Returns:
            The Blob.

        Raises:
            TypeError: If blob_input is neither a file object nor bytes.

        """
        if hasattr(blob_input, "read"):
            raw = blob_input
        else:
            if not isinstance(blob_input, bytes):
                raise TypeError("AvroBlob.load() requires bytes or a binary file object, "
                                f"not {type(blob_input).__name__}")
            raw = BytesIO(blob_input)
        return cls._read_avro(raw, single_record=single_record)

    @classmethod
    def load_file(cls, filename, single_record=True):
        """Create a new message model from a file.

        Args:
            filename: The path to a file.
            single_record: True if input avro data only contains one record.
                True by default.

        Returns:
            The message model.

        """
        with open(filename, "rb") as f:
            return cls.load(f, single_record=single_record)

    def __eq__(self, other):
        """Compare two AvroBlobs by content only."""
        if type(self) is not type(other):
            return False
        # compare only content, not schemas
        return self.content == other.content

    def __hash__(self):
        """Refuse hashing; always raises NotImplementedError."""
        raise NotImplementedError("AvroBlob objects are not hashable")
@dataclass
class ExternalMessage(MessageModel):
    """Defines a message which refers to data stored externally at some URL
    """

    url: str
    format_name = "external"

    @classmethod
    def load(cls, input):
        """Create an external message from its JSON description.

        The input must be a JSON object with a "url" key; only that key
        is used.

        Args:
            input: The JSON text (str or UTF-8 bytes) or a file object
                to read it from.

        Returns:
            The ExternalMessage.

        Raises:
            KeyError: If the decoded JSON object has no "url" key.

        """
        if hasattr(input, "read"):
            raw = input.read()
        else:
            raw = input
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8")
        decoded = json.loads(raw)
        return cls(url=decoded["url"])
@plugins.register
def get_models():
    """Return the message models defined in this module.

    Returns:
        A dictionary mapping each model's format label to its class.

    """
    registry = {}
    for model in (
        VOEvent,
        GCNTextNotice,
        GCNCircular,
        Blob,
        JSONBlob,
        AvroBlob,
        ExternalMessage,
    ):
        registry[format_name(model)] = model
    return registry