# utils.py (1.97 KiB)
import sys
import types
from typing import Any, Dict, ForwardRef, Tuple, Union
if sys.version_info >= (3, 14):
    from typing import evaluate_forward_ref, get_args, get_origin

    def evaluate_type(type_: Any, localns) -> Any:
        """Evaluate the ``ForwardRef`` ``type_`` using ``localns`` as the namespace.

        Python 3.14+ exposes the public ``typing.evaluate_forward_ref``; the
        private ``ForwardRef._evaluate`` used on older versions is deprecated there.
        """
        return evaluate_forward_ref(type_, globals=None, locals=localns)

elif sys.version_info >= (3, 12, 4):
    from typing import get_args, get_origin

    def evaluate_type(type_: Any, localns) -> Any:
        """Evaluate the ``ForwardRef`` ``type_`` using ``localns`` as the namespace."""
        # From 3.12.4, omitting ``type_params`` emits a DeprecationWarning.
        # ``()`` means no PEP 695 type parameters are made available during
        # evaluation, which matches the older default behaviour.
        return type_._evaluate(
            globalns=None,
            localns=localns,
            type_params=(),
            recursive_guard=frozenset(),
        )

elif sys.version_info >= (3, 9, 0):
    from typing import get_args, get_origin

    def evaluate_type(type_: Any, localns) -> Any:
        """Evaluate the ``ForwardRef`` ``type_`` using ``localns`` as the namespace."""
        # 3.9+ requires ``recursive_guard``; it is passed by keyword because it
        # later became keyword-only.
        return type_._evaluate(
            globalns=None, localns=localns, recursive_guard=frozenset()
        )

else:

    def get_args(type_annotation: Any) -> Any:
        """Fallback for old Pythons: return ``__args__`` or an empty tuple."""
        return getattr(type_annotation, "__args__", tuple())

    def get_origin(type_annotation: Any) -> Any:
        """Fallback for old Pythons: return ``__origin__`` or ``None``."""
        return getattr(type_annotation, "__origin__", None)

    def evaluate_type(type_: Any, localns) -> Any:
        """Evaluate the ``ForwardRef`` ``type_`` using ``localns`` as the namespace."""
        # Before 3.9 ``ForwardRef._evaluate`` takes no ``recursive_guard``.
        return type_._evaluate(globalns=None, localns=localns)
def type_to_str(_type: type) -> str:
    """Return a string representation of ``_type``; ``ForwardRef`` is supported.

    - a class yields its ``__name__`` (``int`` -> ``"int"``);
    - a ``ForwardRef`` yields the referenced name (``ForwardRef("A")`` -> ``"A"``);
    - a plain string annotation is returned unchanged;
    - anything else (e.g. ``Union[A, B]``) yields ``str(_type)``.
    """
    if isinstance(_type, type):
        return _type.__name__
    if isinstance(_type, ForwardRef):
        return _type.__forward_arg__
    if isinstance(_type, str):
        # String annotations are unresolved forward references by name.
        return _type
    # Fallback so every input produces a ``str``, as the annotation promises,
    # instead of implicitly returning ``None``.
    return str(_type)
def type_or_union_to_tuple(type_: type) -> Tuple[type, ...]:
    """Convert a union type into a tuple of its member types.

    ``Union[TypeA, TypeB]`` (and, on Python 3.10+, ``TypeA | TypeB``) becomes
    ``(TypeA, TypeB)``. If ``type_`` is not a union, ``(type_,)`` is returned.
    """
    origin = get_origin(type_)
    # ``types.UnionType`` (the origin of PEP 604 ``X | Y``) only exists on 3.10+.
    pep604_union = getattr(types, "UnionType", None)
    if origin is Union or (pep604_union is not None and origin is pep604_union):
        return tuple(get_args(type_))
    return (type_,)
def update_union_forward_refs(field, schema):
    """Resolve ``ForwardRef`` members of a ``Union`` field type in place.

    The standard pydantic resolver does not turn
    ``Union[ForwardRef(A), ForwardRef(B)]`` into ``Union[A, B]``. This helper
    evaluates each ``ForwardRef`` argument of ``field.type_`` against
    ``schema`` and writes the result back to ``field.type_.__args__``.
    Non-``ForwardRef`` members (e.g. ``None`` in an ``Optional``) are kept
    unchanged and in their original order. Non-union types are left untouched.

    :param field: object exposing a ``type_`` attribute (presumably a pydantic
        ``ModelField`` -- confirm against callers).
    :param schema: mapping of names to classes used as the local namespace
        when evaluating forward references.
    """
    # TODO log as issue with Pydantic
    if get_origin(field.type_) is not Union:
        return
    old_args = get_args(field.type_)
    new_args = tuple(
        evaluate_type(arg, localns=schema) if isinstance(arg, ForwardRef) else arg
        for arg in old_args
    )
    if new_args != old_args:
        # NOTE(review): this mutates the Union object itself. typing may cache
        # and share parametrized Union objects, so the change could be visible
        # wherever the same Union is used. Confirm that this is intended.
        field.type_.__args__ = new_args
def resolve_forward_refs(schema: Dict[str, type]):
    """Resolve forward references for every class in ``schema.values()``.

    Each class first gets ``update_forward_refs`` called with the schema
    entries as keyword arguments. Then every value of its ``__edge_fields__``
    is passed to ``update_union_forward_refs``, so ``Union`` members are
    resolved as well.
    """
    # Resolve against a shallow copy so in-place changes to the namespace made
    # during resolution stay out of the caller's ``schema``.
    namespace = schema.copy()
    classes = [model for model in schema.values()]
    for model in classes:
        model.update_forward_refs(**namespace)
        for edge_field in model.__edge_fields__.values():
            update_union_forward_refs(edge_field, namespace)