Skip to content

Commit 1c61ea9

Browse files
committed
Add type hint parsing utilities
1 parent e20451f commit 1c61ea9

File tree

4 files changed

+328
-0
lines changed

4 files changed

+328
-0
lines changed

src/typing_inspection/introspection/__init__.py

Whitespace-only changes.
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import types
2+
import sys
3+
import functools
4+
import operator
5+
import collections.abc
6+
from typing import Any, ForwardRef, Literal
7+
8+
from typing_extensions import Unpack, get_origin
9+
10+
from ._types import GenericAliasProto
11+
from ._utils import _is_param_expr
12+
from typing_inspection import typing_objects
13+
14+
15+
16+
class UnevaluatedTypeHint(Exception):
17+
"""The type hint wasn't evaluated as it still contains forward references."""
18+
19+
forward_arg: ForwardRef | str
20+
"""The forward reference that wasn't evaluated."""
21+
22+
def __init__(self, forward_arg: ForwardRef | str) -> None:
23+
self.forward_arg = forward_arg
24+
25+
class TypeHintVisitor:
26+
27+
def visit(self, hint: Any) -> None:
28+
if typing_objects.is_paramspecargs(hint) or typing_objects.is_paramspeckwargs(hint):
29+
return self.visit_bare_hint(hint)
30+
origin = get_origin(hint)
31+
if typing_objects.is_generic(origin):
32+
# `get_origin()` returns `Generic` if `hint` is `typing.Generic` (or `Generic[...]`).
33+
raise ValueError(f'{hint} is invalid in an annotation expression')
34+
35+
if origin is not None:
36+
if hint in typing_objects.DEPRECATED_ALIASES:
37+
# For *bare* deprecated aliases (such as `typing.List`), `get_origin()` returns the
38+
# actual type (such as `list`). As such, we treat `hint` as a bare hint.
39+
self.visit_bare_hint(hint)
40+
elif sys.version_info >= (3, 10) and origin is types.UnionType:
41+
self.visit_union(hint)
42+
else:
43+
self.visit_generic_alias(hint, origin)
44+
else:
45+
self.visit_bare_hint(hint)
46+
47+
# origin = get_origin(hint)
48+
# if origin in DEPRECATED_ALIASES.values() and not isinstance(hint, types.GenericAlias):
49+
# # hint is a deprecated generic alias, e.g. `List[int]`.
50+
# # `get_origin(List[int])` returns `list`, but we want to preserve
51+
# # `List` as the actual origin.
52+
53+
def visit_generic_alias(self, hint: GenericAliasProto, origin: Any) -> None:
54+
if not typing_objects.is_literal(origin):
55+
# Note: it is important to use `hint.__args__` instead of `get_args()` as
56+
# they differ for some typing forms (e.g. `Annotated`, `Callable`).
57+
# `hint.__args__` should be guaranteed to only contain other annotation expressions.
58+
for arg in hint.__args__:
59+
self.visit(arg)
60+
61+
if sys.version_info >= (3, 10):
62+
def visit_union(self, hint: types.UnionType) -> None:
63+
for arg in hint.__args__:
64+
self.visit(arg)
65+
66+
def visit_bare_hint(self, hint: Any) -> None:
67+
if typing_objects.is_forwardref(hint) or isinstance(hint, str):
68+
self.visit_forward_hint(hint)
69+
70+
def visit_forward_hint(self, hint: ForwardRef | str) -> None:
71+
raise UnevaluatedTypeHint(hint)
72+
73+
74+
# Backport of `typing._should_unflatten_callable_args`:
75+
def _should_unflatten_callable_args(alias: types.GenericAlias, args: tuple[Any, ...]) -> bool:
76+
return (
77+
alias.__origin__ is collections.abc.Callable # pyright: ignore
78+
and not (len(args) == 2 and _is_param_expr(args[0]))
79+
)
80+
81+
82+
class TypeHintTransformer:
83+
84+
def visit(self, hint: Any) -> Any:
85+
if typing_objects.is_paramspecargs(hint) or typing_objects.is_paramspeckwargs(hint):
86+
return self.visit_bare_hint(hint)
87+
origin = get_origin(hint)
88+
if typing_objects.is_generic(origin):
89+
# `get_origin()` returns `Generic` if `hint` is `typing.Generic` (or `Generic[...]).
90+
raise ValueError(f'{hint} is invalid in an annotation expression')
91+
92+
if origin is not None:
93+
if hint in typing_objects.DEPRECATED_ALIASES:
94+
# For *bare* deprecated aliases (such as `typing.List`), `get_origin()` returns the
95+
# actual type (such as `list`). As such, we treat `hint` as a constant.
96+
return self.visit_bare_hint(hint)
97+
elif sys.version_info >= (3, 10) and origin is types.UnionType:
98+
return self.visit_union(hint)
99+
else:
100+
return self.visit_generic_alias(hint, origin)
101+
else:
102+
return self.visit_bare_hint(hint)
103+
104+
def visit_generic_alias(self, hint: GenericAliasProto, origin: Any) -> Any:
105+
if typing_objects.is_literal(origin):
106+
return hint
107+
108+
visited_args = tuple(self.visit(arg) for arg in hint.__args__)
109+
if visited_args == hint.__args__:
110+
return hint
111+
112+
if isinstance(hint, types.GenericAlias):
113+
# Logic from `typing._eval_type()`:
114+
is_unpacked = hint.__unpacked__
115+
if _should_unflatten_callable_args(hint, visited_args):
116+
t = hint.__origin__[(visited_args[:-1], visited_args[-1])]
117+
else:
118+
t = hint.__origin__[visited_args]
119+
if is_unpacked:
120+
t = Unpack[t]
121+
return t
122+
else:
123+
# `.copy_with()` is a method present on the private `typing._GenericAlias` class.
124+
# Many generic aliases (e.g. `Concatenate[]`) have special logic in this method,
125+
# so we can't just do `hint.__origin__[transformed_args]`.
126+
return hint.copy_with(visited_args) # pyright: ignore
127+
128+
if sys.version_info >= (3, 10):
129+
def visit_union(self, hint: types.UnionType) -> Any:
130+
visited_args = tuple(self.visit(arg) for arg in hint.__args__)
131+
if visited_args == hint.__args__:
132+
return hint
133+
return functools.reduce(operator.or_, visited_args)
134+
135+
def visit_bare_hint(self, hint: Any) -> Any:
136+
if typing_objects.is_forwardref(hint) or isinstance(hint, str):
137+
return self.visit_forward_hint(hint)
138+
else:
139+
return hint
140+
141+
def visit_forward_hint(self, hint: ForwardRef | str) -> Any:
142+
raise UnevaluatedTypeHint(hint)
143+
144+
145+
class MultiTransformer(TypeHintTransformer):
146+
def __init__(
147+
self,
148+
unpack_type_aliases: Literal['skip', 'lenient', 'eager'] = 'skip',
149+
type_replacements: dict[Any, Any] = {},
150+
) -> None:
151+
self.unpack_type_aliases: Literal['skip', 'lenient', 'eager'] = unpack_type_aliases
152+
self.type_replacements = type_replacements
153+
154+
def visit_generic_alias(self, hint: GenericAliasProto, origin: Any) -> Any:
155+
args = hint.__args__
156+
if self.unpack_type_aliases != 'skip' and typing_objects.is_typealiastype(origin):
157+
try:
158+
value = origin.__value__
159+
except NameError:
160+
if self.unpack_type_aliases == 'eager':
161+
raise
162+
else:
163+
return self.visit(value[tuple(self.visit(arg) for arg in args)])
164+
return super().visit_generic_alias(hint, origin)
165+
166+
167+
def visit_bare_hint(self, hint: Any) -> Any:
168+
hint = super().visit_bare_hint(hint)
169+
new_hint = self.type_replacements.get(hint, hint)
170+
if self.unpack_type_aliases != 'skip' and typing_objects.is_typealiastype(new_hint):
171+
try:
172+
value = new_hint.__value__
173+
except NameError:
174+
if self.unpack_type_aliases == 'eager':
175+
raise
176+
else:
177+
return self.visit(value)
178+
return new_hint
179+
180+
181+
def transform_hint(
182+
hint: Any,
183+
unpack_type_aliases: Literal['skip', 'lenient', 'eager'] = 'skip',
184+
type_replacements: dict[Any, Any] = {},
185+
) -> Any:
186+
...
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Any, Protocol
2+
3+
from typing_extensions import TypeVar, TypeAlias, ParamSpec, TypeVarTuple
4+
5+
OriginT = TypeVar('OriginT', default=Any)
6+
7+
class GenericAliasProto(Protocol[OriginT]):
8+
"""An instance of a parameterized [generic type][] or typing form.
9+
10+
Depending on the alias, this may be an instance of [`types.GenericAlias`][]
11+
(e.g. `list[int]`) or a private `typing` class (`typing._GenericAlias`).
12+
"""
13+
__origin__: OriginT
14+
__args__: tuple[Any, ...]
15+
__parameters__: tuple[Any, ...]
16+
17+
18+
TypeVarLike: TypeAlias = 'TypeVar | TypeVarTuple | ParamSpec'
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import sys
2+
3+
from typing import Any
4+
5+
from ._types import GenericAliasProto, TypeVarLike
6+
7+
from typing_inspection import typing_objects
8+
9+
from typing_extensions import NoDefault, ParamSpec, get_origin
10+
11+
def get_default(t: TypeVarLike, /) -> Any:
12+
"""Get the default value of a type parameter, if it exists.
13+
14+
Args:
15+
t: The [`TypeVar`][typing.TypeVar], [`TypeVarTuple`][typing.TypeVarTuple] or
16+
[`ParamSpec`][typing.ParamSpec] instance to get the default from.
17+
18+
Returns:
19+
The default value, or [`typing.NoDefault`] if not default is set.
20+
!!! warning
21+
This function may return the [`NoDefault` backport][typing_extensions.NoDefault] backport
22+
from `typing_extensions`. As such, [`typing_objects.is_nodefault()`][typing_inspection.typing_objects.is_nodefault]
23+
should be used.
24+
"""
25+
26+
try:
27+
has_default = t.has_default()
28+
except AttributeError:
29+
return NoDefault
30+
else:
31+
if has_default:
32+
return t.__default__
33+
else:
34+
return NoDefault
35+
36+
37+
def alias_substitutions(alias: GenericAliasProto, /) -> dict[TypeVarLike, Any]:
38+
params: tuple[TypeVarLike, ...] | None = getattr(alias.__origin__, '__parameters__', None)
39+
if params is None:
40+
raise ValueError
41+
42+
origin = alias.__origin__
43+
args = alias.__args__
44+
45+
# TODO checks for invalid params (most of the checks are already performed
46+
# by Python for generic classes, but aren't for type aliases)
47+
...
48+
49+
if typing_objects.is_typealiastype(origin) and len(params) == 1 and typing_objects.is_paramspec(params[0]):
50+
# The end of the documentation section at
51+
# https://docs.python.org/3/library/typing.html#user-defined-generic-types
52+
# says:
53+
# a generic with only one parameter specification variable will accept parameter
54+
# lists in the forms X[[Type1, Type2, ...]] and also X[Type1, Type2, ...].
55+
# However, this convenience isn't applied for type aliases.
56+
if len(args) == 0:
57+
# Unlike user-defined generics, type aliases don't fallback to the default:
58+
arg = get_default(params[0])
59+
if typing_objects.is_nodefault(arg):
60+
raise ValueError
61+
elif len(args) == 1 and not _is_param_expr(args[0]):
62+
arg = args[0]
63+
64+
if not _is_param_expr(arg):
65+
arg = (arg,)
66+
elif isinstance(arg, list):
67+
arg = tuple(arg)
68+
69+
substitutions: dict[TypeVarLike, Any] = {}
70+
71+
typevartuple_param = next((p for p in params if typing_objects.is_typevartuple(p)), None)
72+
73+
if typevartuple_param is not None:
74+
# HARD
75+
pass
76+
else:
77+
strict = {'strict': True} if sys.version_info >= (3, 10) else {}
78+
return dict(zip(params, args), **strict)
79+
80+
81+
class A[*Ts, T]:
82+
a: tuple[int, *Ts]
83+
84+
def func(self, *args: *Ts): pass
85+
86+
87+
88+
A[str, *tuple[*()]]
89+
90+
A[str, *tuple[int, ...]]().a
91+
92+
93+
A[str, *tuple[int, *tuple[str, ...]]]().a
94+
95+
96+
# Backports of private `typing` functions:
97+
98+
# Backport of `typing._is_param_expr`:
99+
def _is_param_expr(arg: Any) -> bool:
100+
return (
101+
arg is ... # as in `Callable[..., Any]`
102+
or isinstance(arg, (tuple, list)) # as in `Callable[[int, str], Any]`
103+
or typing_objects.is_paramspec(arg) # as in `Callable[P, Any]`
104+
or typing_objects.is_concatenate(get_origin(arg)) # as in `Callable[Concatenate[int, P], Any]`
105+
)
106+
107+
# Backports of the `__typing_prepare_subst__` methods of type parameter classes,
108+
# only available in 3.11+:
109+
110+
def _paramspec_prepare_subst(self: ParamSpec, alias: GenericAliasProto, args: tuple[Any, ...]):
111+
params = alias.__parameters__
112+
i = params.index(self)
113+
if i == len(args) and not typing_objects.is_nodefault((default := get_default(self))):
114+
args = (*args, default)
115+
if i >= len(args):
116+
raise TypeError(f"Too few arguments for {alias}")
117+
# Special case where Z[[int, str, bool]] == Z[int, str, bool] in PEP 612.
118+
if len(params) == 1 and not _is_param_expr(args[0]):
119+
assert i == 0
120+
args = (args,)
121+
# Convert lists to tuples to help other libraries cache the results.
122+
elif isinstance(args[i], list):
123+
args = (*args[:i], tuple[args[i]], *args[i + 1: ])
124+
return args

0 commit comments

Comments
 (0)