# Author: Eelco van der Wel (commit e9c082c8)
"""
itembase.py contains the schema base classes, extended from Pydantic.
This file only contains core schema functionality
and does not need to be changed to add new schema definitions.
"""
import uuid
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
List,
Tuple,
Type,
TypeVar,
Union,
no_type_check,
)
from pydantic import BaseModel, Extra, PrivateAttr, validator
from pydantic.fields import Field, FieldInfo, ModelField
from pydantic.generics import GenericModel
from pydantic.main import ModelMetaclass
from typing_extensions import dataclass_transform
from .utils import get_args, get_origin, type_or_union_to_tuple, type_to_str
# --- Typing helpers -------------------------------------------------------
# Generic element type variable.
ItemType = TypeVar("ItemType")
# Type variable for the target of an Edge; restricted to ItemBase subclasses.
TargetType = TypeVar("TargetType", bound="ItemBase")
# Union of every primitive ("POD") value type a property may hold.
PodType = Union[bool, str, int, float, datetime]

# Primitive Python types that are stored as item properties, mapped to the
# schema type name used for each. Insertion order is significant only for
# readability; lookups are by type.
POD_TYPES: Dict[type, str] = dict(
    [
        (bool, "Bool"),
        (str, "Text"),
        (int, "Integer"),
        (float, "Real"),
        (datetime, "DateTime"),
    ]
)
def _field_is_property(field: ModelField) -> bool:
    """Return True if *field* is a primitive item property.

    A property field has an outer type that is one of the ``POD_TYPES`` keys
    and a public name (no leading underscore).
    """
    if field.outer_type_ not in POD_TYPES:
        return False
    return not field.name.startswith("_")
def _field_is_edge(field: ModelField) -> bool:
    """Return True if *field* is an edge field.

    An edge field has a public name (no leading underscore) and an outer type
    whose origin is ``list`` with exactly one type argument, as reported by
    the project's ``get_origin`` / ``get_args`` helpers.

    Inspection is best-effort: any exception raised while examining the field
    means "not an edge" rather than an error.
    """
    try:
        if field.name.startswith("_"):
            return False
        if not get_origin(field.outer_type_) == list:
            return False
        type_args = get_args(field.outer_type_)
        return type_args is not None and len(type_args) == 1
    except Exception:
        return False
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
class _ItemMeta(ModelMetaclass):
    """
    Metaclass for ItemBase that adds required class and instance properties.

    For every class it creates, this metaclass:

    - stores the fields accepted by ``_field_is_property`` in
      ``__property_fields__`` and those accepted by ``_field_is_edge`` in
      ``__edge_fields__``;
    - exposes the names of those fields as the ``properties`` and ``edges``
      lists (kept for backwards compatibility with the old ItemBase);
    - registers the private instance attributes in ``_itembase_private_attrs``.

    Note: _itembase_private_attrs is a workaround that keeps these private
    attributes out of the model signature generated by pyright/pylance.
    The difference is only cosmetic: private attributes can also be defined
    on a model with a `PrivateAttr(default)` value.
    See: https://github.com/pydantic/pydantic/discussions/4563
    """

    # Private per-instance state added to every ItemBase model.
    _itembase_private_attrs = {
        "__edges__": PrivateAttr(default_factory=dict),
        "_in_pod": PrivateAttr(default=False),
        "_new_edges": PrivateAttr(default_factory=list),
        "_date_local_modified": PrivateAttr(default_factory=dict),
        "_original_properties": PrivateAttr(default_factory=dict),
    }

    @no_type_check
    def __new__(mcs, name, bases, namespace, **kwargs) -> Any:
        new_cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        model_fields = new_cls.__fields__

        # Classify pydantic fields into primitive properties and edges.
        new_cls.__property_fields__ = {
            field_name: model_field
            for field_name, model_field in model_fields.items()
            if _field_is_property(model_field)
        }
        new_cls.__edge_fields__ = {
            field_name: model_field
            for field_name, model_field in model_fields.items()
            if _field_is_edge(model_field)
        }

        # Backwards compatibility with the old ItemBase: plain name lists.
        new_cls.properties = list(new_cls.__property_fields__)
        new_cls.edges = list(new_cls.__edge_fields__)

        # Register the ItemBase private attributes; on a name clash they win
        # over private attributes pydantic already collected for the class.
        merged_private_attrs = dict(new_cls.__private_attributes__)
        merged_private_attrs.update(mcs._itembase_private_attrs)
        new_cls.__private_attributes__ = merged_private_attrs

        # Also list the new private attributes in __slots__.
        new_cls.__slots__ = new_cls.__slots__ | mcs._itembase_private_attrs.keys()
        return new_cls
class Edge(GenericModel, Generic[TargetType], smart_union=True, copy_on_model_validation=False):
source: "ItemBase"
target: TargetType
name: str
def __init__(
    self,
    source: Any = None,
    target: Union[TargetType, None] = None,
    name: Union[str, None] = None,
) -> None:
    """Create an edge called *name* from *source* to *target*.

    All three values are validated by pydantic first. Afterwards ``source``
    and ``target`` are re-assigned to the exact objects that were passed in.

    Args:
        source: The item the edge starts from.
        target: The item the edge points to; checked by ``validate_target``.
        name: The edge name.
    """
    super().__init__(source=source, target=target, name=name)
    # Workaround: validators occasionally set source or target to a clone;
    # restore the original objects so the edge refers to the caller's items.
    self.source = source
    self.target = target
@classmethod
def get_target_types(cls) -> Tuple[type, ...]:
    """Return the types an edge ``target`` must be an instance of.

    The ``target`` annotation is taken from the nearest class in the MRO that
    declares one. As a result, a subclass of a parametrised edge that does not
    redeclare ``target`` still resolves to its parent's target type. On
    Python >= 3.10 such a class has its own empty ``__annotations__``.

    Returns:
        An empty tuple when the annotation is still the unbound ``TargetType``
        variable (no restriction). Otherwise, the result of
        ``type_or_union_to_tuple`` applied to the annotation.

    Raises:
        KeyError: If no class in the MRO annotates ``target``.
    """
    for klass in cls.__mro__:
        annotations = getattr(klass, "__annotations__", None) or {}
        if "target" in annotations:
            target_annotation = annotations["target"]
            break
    else:
        raise KeyError("target")

    if target_annotation == TargetType:
        return tuple()
    return type_or_union_to_tuple(target_annotation)
@classmethod
def get_target_types_as_str(cls) -> Tuple[str, ...]:
    """Return ``get_target_types()`` with each type converted by ``type_to_str``.

    The result is empty when the edge has no target-type restriction.
    """
    return tuple(type_to_str(t) for t in cls.get_target_types())
@validator("target", pre=True, allow_reuse=True)
def validate_target(cls, val: Any) -> TargetType:
    """Pre-validator for ``target``: check it against the edge's target types.

    If the edge declares no concrete target types (``get_target_types()`` is
    empty), any value is accepted unchanged. Otherwise *val* must be an
    instance of one of those types.

    Raises:
        ValueError: If *val* is not an instance of any allowed target type.
    """
    allowed_types = cls.get_target_types()
    # An unparametrised edge imposes no restriction, so skip isinstance then.
    if not allowed_types or isinstance(val, allowed_types):
        return val
    expected_display = cls.__fields__["target"]._type_display()
    raise ValueError(f"target with type `{type(val).__name__}` is not a `{expected_display}`")
def __eq__(self, other):
if (
isinstance(other, Edge)
and self.name == other.name
and self.source is other.source
and self.target is other.target
):
return True