Skip to content

Commit 0ea8488

Browse files
authored
Fix nondeterministic type checking by making join with explicit Protocol and type promotion commute (#18402)
Fixes #16979 (bzoracler case only, OP case fixed by #19147) See #16979 (comment)
1 parent a16521f commit 0ea8488

File tree

3 files changed

+89
-3
lines changed

3 files changed

+89
-3
lines changed

mypy/join.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import mypy.typeops
99
from mypy.expandtype import expand_type
1010
from mypy.maptype import map_instance_to_supertype
11-
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY
11+
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY, TypeInfo
1212
from mypy.state import state
1313
from mypy.subtypes import (
1414
SubtypeContext,
@@ -168,9 +168,20 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
168168
# Compute the "best" supertype of t when joined with s.
169169
# The definition of "best" may evolve; for now it is the one with
170170
# the longest MRO. Ties are broken by using the earlier base.
171-
best: ProperType | None = None
171+
172+
# Go over both sets of bases in case there's an explicit Protocol base. This is important
173+
# to ensure commutativity of join (although in cases where both classes have relevant
174+
# Protocol bases this maybe might still not be commutative)
175+
base_types: dict[TypeInfo, None] = {} # dict to deduplicate but preserve order
172176
for base in t.type.bases:
173-
mapped = map_instance_to_supertype(t, base.type)
177+
base_types[base.type] = None
178+
for base in s.type.bases:
179+
if base.type.is_protocol and is_subtype(t, base):
180+
base_types[base.type] = None
181+
182+
best: ProperType | None = None
183+
for base_type in base_types:
184+
mapped = map_instance_to_supertype(t, base_type)
174185
res = self.join_instances(mapped, s)
175186
if best is None or is_better(res, best):
176187
best = res
@@ -662,6 +673,10 @@ def is_better(t: Type, s: Type) -> bool:
662673
if isinstance(t, Instance):
663674
if not isinstance(s, Instance):
664675
return True
676+
if t.type.is_protocol != s.type.is_protocol:
677+
if t.type.fullname != "builtins.object" and s.type.fullname != "builtins.object":
678+
# mro of protocol is not really relevant
679+
return not t.type.is_protocol
665680
# Use len(mro) as a proxy for the better choice.
666681
if len(t.type.mro) > len(s.type.mro):
667682
return True

test-data/unit/check-inference.test

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3888,6 +3888,53 @@ def a4(x: List[str], y: List[Never]) -> None:
38883888
z1[1].append("asdf") # E: "object" has no attribute "append"
38893889
[builtins fixtures/dict.pyi]
38903890

3891+
3892+
[case testDeterminismCommutativityWithJoinInvolvingProtocolBaseAndPromotableType]
3893+
# flags: --python-version 3.11
3894+
# Regression test for https://github.com/python/mypy/issues/16979#issuecomment-1982246306
3895+
from __future__ import annotations
3896+
3897+
from typing import Any, Generic, Protocol, TypeVar, overload, cast
3898+
from typing_extensions import Never
3899+
3900+
T = TypeVar("T")
3901+
U = TypeVar("U")
3902+
3903+
class _SupportsCompare(Protocol):
3904+
def __lt__(self, other: Any, /) -> bool:
3905+
return True
3906+
3907+
class Comparable(_SupportsCompare):
3908+
pass
3909+
3910+
comparable: Comparable = Comparable()
3911+
3912+
from typing import _promote
3913+
3914+
class floatlike:
3915+
def __lt__(self, other: floatlike, /) -> bool: ...
3916+
3917+
@_promote(floatlike)
3918+
class intlike:
3919+
def __lt__(self, other: intlike, /) -> bool: ...
3920+
3921+
3922+
class A(Generic[T, U]):
3923+
@overload
3924+
def __init__(self: A[T, T], a: T, b: T, /) -> None: ... # type: ignore[overload-overlap]
3925+
@overload
3926+
def __init__(self: A[T, U], a: T, b: U, /) -> Never: ...
3927+
def __init__(self, *a) -> None: ...
3928+
3929+
def join(a: T, b: T) -> T: ...
3930+
3931+
reveal_type(join(intlike(), comparable)) # N: Revealed type is "__main__._SupportsCompare"
3932+
reveal_type(join(comparable, intlike())) # N: Revealed type is "__main__._SupportsCompare"
3933+
reveal_type(A(intlike(), comparable)) # N: Revealed type is "__main__.A[__main__._SupportsCompare, __main__._SupportsCompare]"
3934+
reveal_type(A(comparable, intlike())) # N: Revealed type is "__main__.A[__main__._SupportsCompare, __main__._SupportsCompare]"
3935+
[builtins fixtures/tuple.pyi]
3936+
[typing fixtures/typing-medium.pyi]
3937+
38913938
[case testTupleJoinFallbackInference]
38923939
foo = [
38933940
(1, ("a", "b")),

test-data/unit/check-protocols.test

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4461,6 +4461,30 @@ f2(a4) # E: Argument 1 to "f2" has incompatible type "A4"; expected "P2" \
44614461
# N: foo: expected setter type "C1", got "str"
44624462
[builtins fixtures/property.pyi]
44634463

4464+
4465+
[case testExplicitProtocolJoinPreference]
4466+
from typing import Protocol, TypeVar
4467+
4468+
T = TypeVar("T")
4469+
4470+
class Proto1(Protocol):
4471+
def foo(self) -> int: ...
4472+
class Proto2(Proto1):
4473+
def bar(self) -> str: ...
4474+
class Proto3(Proto2):
4475+
def baz(self) -> str: ...
4476+
4477+
class Base: ...
4478+
4479+
class A(Base, Proto3): ...
4480+
class B(Base, Proto3): ...
4481+
4482+
def join(a: T, b: T) -> T: ...
4483+
4484+
def main(a: A, b: B) -> None:
4485+
reveal_type(join(a, b)) # N: Revealed type is "__main__.Proto3"
4486+
reveal_type(join(b, a)) # N: Revealed type is "__main__.Proto3"
4487+
44644488
[case testProtocolImplementationWithDescriptors]
44654489
from typing import Any, Protocol
44664490

0 commit comments

Comments
 (0)