-
Eelco van der Wel authored 0d35b41a
import sys
from typing import Any, Dict, ForwardRef, Tuple, Union
# ``evaluate_type`` resolves a ``typing.ForwardRef`` through the private
# ``ForwardRef._evaluate`` API, whose signature differs across Python versions.
# NOTE: with ``globalns=None`` typing reuses ``localns`` as the globals mapping,
# so ``eval`` may insert a ``__builtins__`` key into the mapping passed in;
# callers should pass a mapping they own (e.g. a copy).
if sys.version_info >= (3, 9, 0):
    from typing import get_args, get_origin

    if sys.version_info >= (3, 12, 4):

        def evaluate_type(type_: Any, localns) -> Any:
            """Evaluate the ForwardRef ``type_`` with ``localns`` as namespace.

            Since Python 3.12.4 ``ForwardRef._evaluate`` takes a ``type_params``
            argument, and newer versions emit a DeprecationWarning when it is
            omitted. PEP 695 type parameters are not used here, so an empty
            tuple is passed explicitly.
            """
            return type_._evaluate(
                globalns=None,
                localns=localns,
                type_params=(),
                recursive_guard=set(),
            )

    else:

        def evaluate_type(type_: Any, localns) -> Any:
            """Evaluate the ForwardRef ``type_`` with ``localns`` as namespace."""
            return type_._evaluate(
                globalns=None, localns=localns, recursive_guard=set()
            )

else:
    # Fallbacks for Python < 3.9: ``typing.get_args``/``get_origin`` only exist
    # from 3.8, and ``ForwardRef._evaluate`` gained ``recursive_guard`` in 3.9.

    def get_args(type_annotation: Any) -> Any:
        """Return the type arguments of ``type_annotation`` (``()`` if none)."""
        return getattr(type_annotation, "__args__", tuple())

    def get_origin(type_annotation: Any) -> Any:
        """Return the origin of ``type_annotation`` (``None`` if it has none)."""
        return getattr(type_annotation, "__origin__", None)

    def evaluate_type(type_: Any, localns) -> Any:
        """Evaluate the ForwardRef ``type_`` with ``localns`` as namespace."""
        return type_._evaluate(globalns=None, localns=localns)
def type_to_str(_type: type) -> str:
    """Return a string representation of ``_type``; ForwardRef is supported.

    * a class returns its ``__name__`` (``int`` -> ``"int"``);
    * a ``ForwardRef`` returns its unevaluated argument
      (``ForwardRef("Foo")`` -> ``"Foo"``);
    * a plain string annotation is returned unchanged;
    * anything else (e.g. ``List[int]``, ``Union[A, B]``) falls back to
      ``str(_type)``.
    """
    if isinstance(_type, type):
        return _type.__name__
    if isinstance(_type, ForwardRef):
        return _type.__forward_arg__
    if isinstance(_type, str):
        return _type
    # Fallback for other typing constructs so the result is always a str,
    # as the return annotation promises.
    return str(_type)
def type_or_union_to_tuple(type_: type) -> Tuple[type, ...]:
    """Convert ``Union[TypeA, TypeB]`` to ``(TypeA, TypeB)``.

    On Python 3.10+ a PEP 604 union (``TypeA | TypeB``) is split the same way.
    If ``type_`` is not a union, return ``(type_,)``.
    """
    import types

    # ``types.UnionType`` (the origin of ``A | B``) only exists on 3.10+;
    # on older versions ``typing.Union`` is the only union origin.
    union_origins = (Union, getattr(types, "UnionType", Union))
    if get_origin(type_) in union_origins:
        return tuple(get_args(type_))
    return (type_,)
def update_union_forward_refs(field, schema):
    """Resolve the ForwardRef members of a Union-typed field in place.

    The standard pydantic resolver does not turn
    ``Union[ForwardRef(A), ForwardRef(B)]`` into ``Union[A, B]``, so this
    helper evaluates every ForwardRef argument of ``field.type_`` against
    ``schema`` and writes the result back into ``field.type_.__args__``.
    Arguments that are not ForwardRefs (e.g. ``NoneType`` from ``Optional``)
    are kept unchanged and in their original position. Non-Union fields are
    left untouched.

    Args:
        field: object exposing a ``type_`` attribute — presumably a pydantic
            ``ModelField``; confirm against callers.
        schema: mapping of names to classes, used as the local namespace
            when evaluating the forward references.
    """
    # TODO log as issue with Pydantic
    if get_origin(field.type_) is not Union:
        return
    args = get_args(field.type_)
    if not any(isinstance(arg, ForwardRef) for arg in args):
        return
    new_args = tuple(
        evaluate_type(arg, localns=schema) if isinstance(arg, ForwardRef) else arg
        for arg in args
    )
    # NOTE(review): this mutates the Union object itself rather than rebinding
    # ``field.type_``; typing may share that object with other annotations.
    # Kept in place to preserve behavior for code holding the same object.
    field.type_.__args__ = new_args
def resolve_forward_refs(schema: Dict[str, type]):
    """Resolve forward references for every class in ``schema``.

    ``schema`` maps names to classes and is the namespace the references are
    resolved against. Each class first gets ``update_forward_refs`` called
    with the whole namespace; then every field in its ``__edge_fields__`` has
    its Union forward references resolved.
    """
    # Resolve against a shallow copy so that in-place modifications made while
    # resolving do not leak into the caller's ``schema``.
    namespace = schema.copy()
    model_classes = tuple(schema.values())
    for model_cls in model_classes:
        model_cls.update_forward_refs(**namespace)
        edge_fields = model_cls.__edge_fields__
        for edge_field in edge_fields.values():
            update_union_forward_refs(edge_field, namespace)