from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass, field
import email
import fastavro
from io import BytesIO
import json
from typing import Any, Dict, List, Union
import collections.abc
import xmltodict
from . import plugins
from . import avro_utils
# Type alias for any value that can be represented directly in JSON
# (scalars, null, objects keyed by strings, and arrays).
JSONType = Union[
    str,
    int,
    float,
    bool,
    None,
    Dict[str, Any],
    List[Any],
]
def format_name(cls):
    """Return the format label for a message model class.

    A class can choose its label explicitly through a ``format_name`` class
    attribute; otherwise the label is the class name in lowercase.

    Args:
        cls: The message model class.

    Returns:
        The format label.
    """
    try:
        return cls.format_name
    except AttributeError:
        # No explicit label: fall back to the lowercased class name.
        return cls.__name__.lower()
@dataclass
class MessageModel(ABC):
    """An abstract message model.

    The default ``serialize``/``deserialize`` implementations encode the
    model's dataclass fields as UTF-8 JSON. Subclasses must implement ``load``.
    """

    def serialize(self):
        """Wrap the message with its format and content.

        Returns:
            A dictionary with "format" and "content" keys.
            The value stored under "format" is the format label.
            The value stored under "content" is the actual encoded data.
        """
        # By default, encode the dataclass fields as UTF-8 JSON.
        return {
            "format": format_name(type(self)),
            "content": json.dumps(asdict(self)).encode("utf-8"),
        }

    @classmethod
    def deserialize(cls, data):
        """Unwrap a message produced by serialize() (the "content" value).

        Args:
            data: UTF-8 encoded JSON bytes whose keys are the model's fields.

        Returns:
            An instance of the model class.
        """
        # Inverse of the default serialize(): decode UTF-8 JSON into field values.
        return cls(**json.loads(data.decode("utf-8")))

    @classmethod
    def load_file(cls, filename):
        """Create a new message model from a text file.

        The file is decoded as UTF-8, matching the encoding used by
        ``serialize``/``deserialize``, instead of the platform's locale
        encoding (which would mis-decode non-ASCII text on some systems).

        Args:
            filename: The path to a file.

        Returns:
            The message model.
        """
        with open(filename, "r", encoding="utf-8") as f:
            return cls.load(f)

    @classmethod
    @abstractmethod
    def load(cls, input_):
        """Create a new message model from a file object or string.

        This base implementation has no functionality and should not be called.

        Args:
            input_: A file object or string.

        Returns:
            The message model.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("MessageModel.load() should not be called")
@dataclass
class VOEvent(MessageModel):
    """Defines a VOEvent 2.0 structure.

    Implements the schema defined by:
    http://www.ivoa.net/Documents/VOEvent/20110711/

    Each top-level section (Who, What, WhereWhen, ...) holds the value that
    ``xmltodict`` produced for the corresponding XML element.
    """

    ivorn: str
    role: str = "observation"
    version: str = "2.0"
    Who: dict = field(default_factory=dict)
    What: dict = field(default_factory=dict)
    WhereWhen: dict = field(default_factory=dict)
    How: dict = field(default_factory=dict)
    Why: dict = field(default_factory=dict)
    Citations: dict = field(default_factory=dict)
    Description: dict = field(default_factory=dict)
    Reference: dict = field(default_factory=dict)

    def __str__(self):
        as_mapping = asdict(self)
        return json.dumps(as_mapping, indent=2)

    @classmethod
    def load(cls, xml_input):
        """Create a new VOEvent from an XML-formatted VOEvent.

        Args:
            xml_input: A file object, string, or generator.

        Returns:
            The VOEvent.
        """
        # An empty attribute prefix lets root attributes (e.g. ivorn, role,
        # version) appear as plain keys that can map directly onto fields.
        document = xmltodict.parse(xml_input, attr_prefix="")
        root = document["voe:VOEvent"]
        # Skip XML-specific, namespace-qualified keys (anything containing ":").
        kwargs = {}
        for key, value in root.items():
            if ":" in key:
                continue
            kwargs[key] = value
        return cls(**kwargs)

    @classmethod
    def load_file(cls, filename):
        """Create a new VOEvent from an XML-formatted VOEvent file.

        The file is opened in binary mode and handed to ``load``.

        Args:
            filename: Name of the VOEvent file.

        Returns:
            The VOEvent.
        """
        with open(filename, "rb") as xml_file:
            event = cls.load(xml_file)
        return event
@dataclass
class GCNCircular(MessageModel):
    """Defines a GCN Circular structure.

    The parsed GCN circular is formatted as a dictionary with
    the following schema:
    {'header': {'title': ..., 'number': ..., ...}, 'body': ...}

    Header names are lowercased when a circular is loaded.
    """

    header: dict
    body: str
    format_name = "circular"

    def __str__(self):
        # Render each header as an upper-case "NAME:" label padded to 9
        # columns, then a blank line, then the body.
        lines = []
        for name, val in self.header.items():
            label = (name.upper() + ":").ljust(9)
            lines.append(label + val)
        lines.append("")
        lines.append(self.body)
        return "\n".join(lines)

    @classmethod
    def load(cls, email_input):
        """Create a new GCNCircular from an RFC 822 formatted circular.

        Args:
            email_input: A file object (anything with ``read``) or a string.

        Returns:
            The GCNCircular.
        """
        if hasattr(email_input, "read"):
            parse = email.message_from_file
        else:
            parse = email.message_from_string
        message = parse(email_input)
        # Split the message into lowercased headers and its payload body.
        header = {}
        for title, content in message.items():
            header[title.lower()] = content
        return cls(header=header, body=message.get_payload())
@dataclass
class Blob(MessageModel):
    """Defines an opaque message blob.

    The content is raw bytes, passed through unchanged by ``serialize`` and
    ``deserialize``.
    """

    content: bytes

    def __str__(self):
        return str(self.content)

    def serialize(self):
        """Wrap the raw content with its format label, without re-encoding it.

        Returns:
            A dictionary with "format" and "content" keys.
        """
        return {"format": format_name(type(self)), "content": self.content}

    @classmethod
    def deserialize(cls, data):
        """Wrap raw data produced by serialize() back into a Blob.

        Args:
            data: The raw message content.

        Returns:
            The Blob.
        """
        return cls(content=data)

    @classmethod
    def load(cls, blob_input):
        """Create a blob message from input data.

        Args:
            blob_input: The unstructured message data (bytes) or file object.

        Returns:
            The Blob.
        """
        if hasattr(blob_input, "read"):
            raw = blob_input.read()
        else:
            raw = blob_input
        return cls(content=raw)

    @classmethod
    def load_file(cls, filename):
        """Create a new Blob from the raw bytes of a file.

        The file is opened in binary mode, so arbitrary (non-text) data is
        preserved exactly and ``content`` is ``bytes``. A text-mode open would
        decode the file into ``str`` and fail on data that is not valid text.

        Args:
            filename: The path to a file.

        Returns:
            The Blob.
        """
        with open(filename, "rb") as f:
            return cls.load(f)
@dataclass
class JSONBlob(MessageModel):
    """Defines a message whose content is an arbitrary JSON-compatible value.

    The content is encoded as UTF-8 JSON text when serialized.
    """

    content: JSONType
    format_name = "json"

    def __str__(self):
        return str(self.content)

    def serialize(self):
        """Wrap the message with its format and JSON-encoded content.

        Returns:
            A dictionary with "format" and "content" keys; "content" holds
            UTF-8 encoded JSON.
        """
        encoded = json.dumps(self.content).encode("utf-8")
        return {"format": format_name(type(self)), "content": encoded}

    @classmethod
    def deserialize(cls, data):
        """Rebuild a JSONBlob from UTF-8 JSON bytes produced by serialize().

        Args:
            data: UTF-8 encoded JSON bytes.

        Returns:
            The JSONBlob.
        """
        text = data.decode("utf-8")
        return cls(content=json.loads(text))

    @classmethod
    def load(cls, blob_input):
        """Create a JSON message from input text.

        Args:
            blob_input: JSON text (str or UTF-8 bytes) or a file object
                yielding either.

        Returns:
            The JSONBlob.
        """
        raw = blob_input.read() if hasattr(blob_input, "read") else blob_input
        # Accept bytes as well as text by decoding them as UTF-8 first.
        text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
        return cls(content=json.loads(text))
@dataclass
class AvroBlob(MessageModel):
    """Defines a message whose content is a sequence of Avro records.

    Records should be JSON-compatible values, since Avro supports essentially
    the same types as JSON. An explicit Avro ``schema`` may be given. If it is
    omitted, a schema is inferred from the records on the first ``serialize``.
    """

    content: List[JSONType]  # serializing as Avro supports essentially the same types as JSON
    schema: dict = None
    format_name = "avro"

    def __init__(self, content: List[JSONType], schema: dict = None):
        """Create an AvroBlob.

        Args:
            content: A sequence of records.
            schema: An optional Avro schema describing the records.

        Raises:
            TypeError: If content is not a sequence, or is a str/bytes-like
                object. Those are sequences of characters or integers, not of
                records.
        """
        is_sequence = isinstance(content, collections.abc.Sequence)
        # str/bytes are registered as Sequences too, but their items are not records.
        if not is_sequence or isinstance(content, (str, bytes, bytearray, memoryview)):
            raise TypeError("AvroBlob requires content to be a sequence of records")
        self.content = content
        self.schema = schema

    def __str__(self):
        return str(self.content)

    def serialize(self):
        """Wrap the message with its format and content.

        Returns:
            A dictionary with "format" and "content" keys; "content" holds an
            Avro object container file as bytes.
        """
        if self.schema is None:
            # Make up an ad-hoc schema from the records. It is stored on the
            # instance, so later serializations reuse it rather than re-inferring.
            self.schema = avro_utils.SchemaGenerator().find_common_type(self.content)
        buffer = BytesIO()
        fastavro.writer(buffer, self.schema, self.content)
        return {"format": format_name(type(self)), "content": buffer.getvalue()}

    @classmethod
    def _read_avro(cls, stream):
        """Decode every record from an Avro container stream, keeping its schema."""
        reader = fastavro.reader(stream)
        records = list(reader)
        return cls(content=records, schema=reader.writer_schema)

    @classmethod
    def deserialize(cls, data):
        """Rebuild an AvroBlob from the bytes produced by serialize()."""
        return cls._read_avro(BytesIO(data))

    @classmethod
    def load(cls, blob_input):
        """Create a blob message from input avro data.

        Args:
            blob_input: The encoded Avro data (bytes-like) or a binary file
                object.

        Returns:
            The AvroBlob.

        Raises:
            TypeError: If blob_input is neither a file object nor bytes-like.
        """
        if hasattr(blob_input, "read"):
            stream = blob_input
        elif isinstance(blob_input, (bytes, bytearray, memoryview)):
            stream = BytesIO(blob_input)
        else:
            raise TypeError(
                "AvroBlob.load requires bytes-like data or a binary file object, "
                f"not {type(blob_input).__name__}"
            )
        return cls._read_avro(stream)

    @classmethod
    def load_file(cls, filename):
        """Create a new AvroBlob from an Avro file (opened in binary mode).

        Args:
            filename: The path to a file.

        Returns:
            The AvroBlob.
        """
        with open(filename, "rb") as f:
            return cls.load(f)

    def __eq__(self, other):
        if type(self) != type(other):
            return False
        # compare only content, not schemas
        return self.content == other.content

    def __hash__(self):
        # Deliberately unhashable: equality is defined by ``content``, which may
        # be a mutable list.
        raise NotImplementedError("AvroBlob objects are not hashable")
@plugins.register
def get_models():
    """Collect the built-in message models, keyed by their format labels.

    The function is decorated with ``plugins.register``; presumably this lets
    the plugin system discover these models (confirm against ``plugins``).

    Returns:
        A dict mapping each model's format label to the model class.
    """
    builtin_models = (
        VOEvent,
        GCNCircular,
        Blob,
        JSONBlob,
        AvroBlob,
    )
    registry = {}
    for model in builtin_models:
        registry[format_name(model)] = model
    return registry