|
| 1 | +import types |
| 2 | +import sys |
| 3 | +import functools |
| 4 | +import operator |
| 5 | +import collections.abc |
| 6 | +from typing import Any, ForwardRef, Literal |
| 7 | + |
| 8 | +from typing_extensions import Unpack, get_origin |
| 9 | + |
| 10 | +from ._types import GenericAliasProto |
| 11 | +from ._utils import _is_param_expr |
| 12 | +from typing_inspection import typing_objects |
| 13 | + |
| 14 | + |
| 15 | + |
| 16 | +class UnevaluatedTypeHint(Exception): |
| 17 | + """The type hint wasn't evaluated as it still contains forward references.""" |
| 18 | + |
| 19 | + forward_arg: ForwardRef | str |
| 20 | + """The forward reference that wasn't evaluated.""" |
| 21 | + |
| 22 | + def __init__(self, forward_arg: ForwardRef | str) -> None: |
| 23 | + self.forward_arg = forward_arg |
| 24 | + |
| 25 | +class TypeHintVisitor: |
| 26 | + |
| 27 | + def visit(self, hint: Any) -> None: |
| 28 | + if typing_objects.is_paramspecargs(hint) or typing_objects.is_paramspeckwargs(hint): |
| 29 | + return self.visit_bare_hint(hint) |
| 30 | + origin = get_origin(hint) |
| 31 | + if typing_objects.is_generic(origin): |
| 32 | + # `get_origin()` returns `Generic` if `hint` is `typing.Generic` (or `Generic[...]`). |
| 33 | + raise ValueError(f'{hint} is invalid in an annotation expression') |
| 34 | + |
| 35 | + if origin is not None: |
| 36 | + if hint in typing_objects.DEPRECATED_ALIASES: |
| 37 | + # For *bare* deprecated aliases (such as `typing.List`), `get_origin()` returns the |
| 38 | + # actual type (such as `list`). As such, we treat `hint` as a bare hint. |
| 39 | + self.visit_bare_hint(hint) |
| 40 | + elif sys.version_info >= (3, 10) and origin is types.UnionType: |
| 41 | + self.visit_union(hint) |
| 42 | + else: |
| 43 | + self.visit_generic_alias(hint, origin) |
| 44 | + else: |
| 45 | + self.visit_bare_hint(hint) |
| 46 | + |
| 47 | + # origin = get_origin(hint) |
| 48 | + # if origin in DEPRECATED_ALIASES.values() and not isinstance(hint, types.GenericAlias): |
| 49 | + # # hint is a deprecated generic alias, e.g. `List[int]`. |
| 50 | + # # `get_origin(List[int])` returns `list`, but we want to preserve |
| 51 | + # # `List` as the actual origin. |
| 52 | + |
| 53 | + def visit_generic_alias(self, hint: GenericAliasProto, origin: Any) -> None: |
| 54 | + if not typing_objects.is_literal(origin): |
| 55 | + # Note: it is important to use `hint.__args__` instead of `get_args()` as |
| 56 | + # they differ for some typing forms (e.g. `Annotated`, `Callable`). |
| 57 | + # `hint.__args__` should be guaranteed to only contain other annotation expressions. |
| 58 | + for arg in hint.__args__: |
| 59 | + self.visit(arg) |
| 60 | + |
| 61 | + if sys.version_info >= (3, 10): |
| 62 | + def visit_union(self, hint: types.UnionType) -> None: |
| 63 | + for arg in hint.__args__: |
| 64 | + self.visit(arg) |
| 65 | + |
| 66 | + def visit_bare_hint(self, hint: Any) -> None: |
| 67 | + if typing_objects.is_forwardref(hint) or isinstance(hint, str): |
| 68 | + self.visit_forward_hint(hint) |
| 69 | + |
| 70 | + def visit_forward_hint(self, hint: ForwardRef | str) -> None: |
| 71 | + raise UnevaluatedTypeHint(hint) |
| 72 | + |
| 73 | + |
| 74 | +# Backport of `typing._should_unflatten_callable_args`: |
| 75 | +def _should_unflatten_callable_args(alias: types.GenericAlias, args: tuple[Any, ...]) -> bool: |
| 76 | + return ( |
| 77 | + alias.__origin__ is collections.abc.Callable # pyright: ignore |
| 78 | + and not (len(args) == 2 and _is_param_expr(args[0])) |
| 79 | + ) |
| 80 | + |
| 81 | + |
| 82 | +class TypeHintTransformer: |
| 83 | + |
| 84 | + def visit(self, hint: Any) -> Any: |
| 85 | + if typing_objects.is_paramspecargs(hint) or typing_objects.is_paramspeckwargs(hint): |
| 86 | + return self.visit_bare_hint(hint) |
| 87 | + origin = get_origin(hint) |
| 88 | + if typing_objects.is_generic(origin): |
| 89 | + # `get_origin()` returns `Generic` if `hint` is `typing.Generic` (or `Generic[...]). |
| 90 | + raise ValueError(f'{hint} is invalid in an annotation expression') |
| 91 | + |
| 92 | + if origin is not None: |
| 93 | + if hint in typing_objects.DEPRECATED_ALIASES: |
| 94 | + # For *bare* deprecated aliases (such as `typing.List`), `get_origin()` returns the |
| 95 | + # actual type (such as `list`). As such, we treat `hint` as a constant. |
| 96 | + return self.visit_bare_hint(hint) |
| 97 | + elif sys.version_info >= (3, 10) and origin is types.UnionType: |
| 98 | + return self.visit_union(hint) |
| 99 | + else: |
| 100 | + return self.visit_generic_alias(hint, origin) |
| 101 | + else: |
| 102 | + return self.visit_bare_hint(hint) |
| 103 | + |
| 104 | + def visit_generic_alias(self, hint: GenericAliasProto, origin: Any) -> Any: |
| 105 | + if typing_objects.is_literal(origin): |
| 106 | + return hint |
| 107 | + |
| 108 | + visited_args = tuple(self.visit(arg) for arg in hint.__args__) |
| 109 | + if visited_args == hint.__args__: |
| 110 | + return hint |
| 111 | + |
| 112 | + if isinstance(hint, types.GenericAlias): |
| 113 | + # Logic from `typing._eval_type()`: |
| 114 | + is_unpacked = hint.__unpacked__ |
| 115 | + if _should_unflatten_callable_args(hint, visited_args): |
| 116 | + t = hint.__origin__[(visited_args[:-1], visited_args[-1])] |
| 117 | + else: |
| 118 | + t = hint.__origin__[visited_args] |
| 119 | + if is_unpacked: |
| 120 | + t = Unpack[t] |
| 121 | + return t |
| 122 | + else: |
| 123 | + # `.copy_with()` is a method present on the private `typing._GenericAlias` class. |
| 124 | + # Many generic aliases (e.g. `Concatenate[]`) have special logic in this method, |
| 125 | + # so we can't just do `hint.__origin__[transformed_args]`. |
| 126 | + return hint.copy_with(visited_args) # pyright: ignore |
| 127 | + |
| 128 | + if sys.version_info >= (3, 10): |
| 129 | + def visit_union(self, hint: types.UnionType) -> Any: |
| 130 | + visited_args = tuple(self.visit(arg) for arg in hint.__args__) |
| 131 | + if visited_args == hint.__args__: |
| 132 | + return hint |
| 133 | + return functools.reduce(operator.or_, visited_args) |
| 134 | + |
| 135 | + def visit_bare_hint(self, hint: Any) -> Any: |
| 136 | + if typing_objects.is_forwardref(hint) or isinstance(hint, str): |
| 137 | + return self.visit_forward_hint(hint) |
| 138 | + else: |
| 139 | + return hint |
| 140 | + |
| 141 | + def visit_forward_hint(self, hint: ForwardRef | str) -> Any: |
| 142 | + raise UnevaluatedTypeHint(hint) |
| 143 | + |
| 144 | + |
| 145 | +class MultiTransformer(TypeHintTransformer): |
| 146 | + def __init__( |
| 147 | + self, |
| 148 | + unpack_type_aliases: Literal['skip', 'lenient', 'eager'] = 'skip', |
| 149 | + type_replacements: dict[Any, Any] = {}, |
| 150 | + ) -> None: |
| 151 | + self.unpack_type_aliases: Literal['skip', 'lenient', 'eager'] = unpack_type_aliases |
| 152 | + self.type_replacements = type_replacements |
| 153 | + |
| 154 | + def visit_generic_alias(self, hint: GenericAliasProto, origin: Any) -> Any: |
| 155 | + args = hint.__args__ |
| 156 | + if self.unpack_type_aliases != 'skip' and typing_objects.is_typealiastype(origin): |
| 157 | + try: |
| 158 | + value = origin.__value__ |
| 159 | + except NameError: |
| 160 | + if self.unpack_type_aliases == 'eager': |
| 161 | + raise |
| 162 | + else: |
| 163 | + return self.visit(value[tuple(self.visit(arg) for arg in args)]) |
| 164 | + return super().visit_generic_alias(hint, origin) |
| 165 | + |
| 166 | + |
| 167 | + def visit_bare_hint(self, hint: Any) -> Any: |
| 168 | + hint = super().visit_bare_hint(hint) |
| 169 | + new_hint = self.type_replacements.get(hint, hint) |
| 170 | + if self.unpack_type_aliases != 'skip' and typing_objects.is_typealiastype(new_hint): |
| 171 | + try: |
| 172 | + value = new_hint.__value__ |
| 173 | + except NameError: |
| 174 | + if self.unpack_type_aliases == 'eager': |
| 175 | + raise |
| 176 | + else: |
| 177 | + return self.visit(value) |
| 178 | + return new_hint |
| 179 | + |
| 180 | + |
| 181 | +def transform_hint( |
| 182 | + hint: Any, |
| 183 | + unpack_type_aliases: Literal['skip', 'lenient', 'eager'] = 'skip', |
| 184 | + type_replacements: dict[Any, Any] = {}, |
| 185 | +) -> Any: |
| 186 | + ... |
0 commit comments