Skip to content

Commit 68233f6

Browse files
authored
Include walrus assignments in conditional inference (#19038)
Fixes #19036.
1 parent 057508b commit 68233f6

File tree

2 files changed

+113
-2
lines changed

2 files changed

+113
-2
lines changed

mypy/checker.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6513,7 +6513,7 @@ def refine_parent_types(self, expr: Expression, expr_type: Type) -> Mapping[Expr
65136513
# and create function that will try replaying the same lookup
65146514
# operation against arbitrary types.
65156515
if isinstance(expr, MemberExpr):
6516-
parent_expr = collapse_walrus(expr.expr)
6516+
parent_expr = self._propagate_walrus_assignments(expr.expr, output)
65176517
parent_type = self.lookup_type_or_none(parent_expr)
65186518
member_name = expr.name
65196519

@@ -6536,9 +6536,10 @@ def replay_lookup(new_parent_type: ProperType) -> Type | None:
65366536
return member_type
65376537

65386538
elif isinstance(expr, IndexExpr):
6539-
parent_expr = collapse_walrus(expr.base)
6539+
parent_expr = self._propagate_walrus_assignments(expr.base, output)
65406540
parent_type = self.lookup_type_or_none(parent_expr)
65416541

6542+
self._propagate_walrus_assignments(expr.index, output)
65426543
index_type = self.lookup_type_or_none(expr.index)
65436544
if index_type is None:
65446545
return output
@@ -6612,6 +6613,24 @@ def replay_lookup(new_parent_type: ProperType) -> Type | None:
66126613
expr = parent_expr
66136614
expr_type = output[parent_expr] = make_simplified_union(new_parent_types)
66146615

6616+
def _propagate_walrus_assignments(
6617+
self, expr: Expression, type_map: dict[Expression, Type]
6618+
) -> Expression:
6619+
"""Add assignments from walrus expressions to inferred types.
6620+
6621+
Only considers nested assignment exprs, does not recurse into other types.
6622+
This may be added later if necessary by implementing a dedicated visitor.
6623+
"""
6624+
if isinstance(expr, AssignmentExpr):
6625+
if isinstance(expr.value, AssignmentExpr):
6626+
self._propagate_walrus_assignments(expr.value, type_map)
6627+
assigned_type = self.lookup_type_or_none(expr.value)
6628+
parent_expr = collapse_walrus(expr)
6629+
if assigned_type is not None:
6630+
type_map[parent_expr] = assigned_type
6631+
return parent_expr
6632+
return expr
6633+
66156634
def refine_identity_comparison_expression(
66166635
self,
66176636
operands: list[Expression],

test-data/unit/check-inference.test

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3979,3 +3979,95 @@ def check(mapping: Mapping[str, _T]) -> None:
39793979
reveal_type(ok1) # N: Revealed type is "Union[_T`-1, builtins.str]"
39803980
ok2: Union[_T, str] = mapping.get("", "")
39813981
[builtins fixtures/tuple.pyi]
3982+
3983+
[case testInferWalrusAssignmentAttrInCondition]
3984+
class Foo:
3985+
def __init__(self, value: bool) -> None:
3986+
self.value = value
3987+
3988+
def check_and(maybe: bool) -> None:
3989+
foo = None
3990+
if maybe and (foo := Foo(True)).value:
3991+
reveal_type(foo) # N: Revealed type is "__main__.Foo"
3992+
else:
3993+
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
3994+
3995+
def check_and_nested(maybe: bool) -> None:
3996+
foo = None
3997+
bar = None
3998+
baz = None
3999+
if maybe and (foo := (bar := (baz := Foo(True)))).value:
4000+
reveal_type(foo) # N: Revealed type is "__main__.Foo"
4001+
reveal_type(bar) # N: Revealed type is "__main__.Foo"
4002+
reveal_type(baz) # N: Revealed type is "__main__.Foo"
4003+
else:
4004+
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
4005+
reveal_type(bar) # N: Revealed type is "Union[__main__.Foo, None]"
4006+
reveal_type(baz) # N: Revealed type is "Union[__main__.Foo, None]"
4007+
4008+
def check_or(maybe: bool) -> None:
4009+
foo = None
4010+
if maybe or (foo := Foo(True)).value:
4011+
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
4012+
else:
4013+
reveal_type(foo) # N: Revealed type is "__main__.Foo"
4014+
4015+
def check_or_nested(maybe: bool) -> None:
4016+
foo = None
4017+
bar = None
4018+
baz = None
4019+
if maybe and (foo := (bar := (baz := Foo(True)))).value:
4020+
reveal_type(foo) # N: Revealed type is "__main__.Foo"
4021+
reveal_type(bar) # N: Revealed type is "__main__.Foo"
4022+
reveal_type(baz) # N: Revealed type is "__main__.Foo"
4023+
else:
4024+
reveal_type(foo) # N: Revealed type is "Union[__main__.Foo, None]"
4025+
reveal_type(bar) # N: Revealed type is "Union[__main__.Foo, None]"
4026+
reveal_type(baz) # N: Revealed type is "Union[__main__.Foo, None]"
4027+
4028+
[case testInferWalrusAssignmentIndexInCondition]
4029+
def check_and(maybe: bool) -> None:
4030+
foo = None
4031+
bar = None
4032+
if maybe and (foo := [1])[(bar := 0)]:
4033+
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
4034+
reveal_type(bar) # N: Revealed type is "builtins.int"
4035+
else:
4036+
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4037+
reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]"
4038+
4039+
def check_and_nested(maybe: bool) -> None:
4040+
foo = None
4041+
bar = None
4042+
baz = None
4043+
if maybe and (foo := (bar := (baz := [1])))[0]:
4044+
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
4045+
reveal_type(bar) # N: Revealed type is "builtins.list[builtins.int]"
4046+
reveal_type(baz) # N: Revealed type is "builtins.list[builtins.int]"
4047+
else:
4048+
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4049+
reveal_type(bar) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4050+
reveal_type(baz) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4051+
4052+
def check_or(maybe: bool) -> None:
4053+
foo = None
4054+
bar = None
4055+
if maybe or (foo := [1])[(bar := 0)]:
4056+
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4057+
reveal_type(bar) # N: Revealed type is "Union[builtins.int, None]"
4058+
else:
4059+
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
4060+
reveal_type(bar) # N: Revealed type is "builtins.int"
4061+
4062+
def check_or_nested(maybe: bool) -> None:
4063+
foo = None
4064+
bar = None
4065+
baz = None
4066+
if maybe or (foo := (bar := (baz := [1])))[0]:
4067+
reveal_type(foo) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4068+
reveal_type(bar) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4069+
reveal_type(baz) # N: Revealed type is "Union[builtins.list[builtins.int], None]"
4070+
else:
4071+
reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]"
4072+
reveal_type(bar) # N: Revealed type is "builtins.list[builtins.int]"
4073+
reveal_type(baz) # N: Revealed type is "builtins.list[builtins.int]"

0 commit comments

Comments
 (0)