Skip to content

Commit

Permalink
Fix signature of Series.map() (#942)
Browse files Browse the repository at this point in the history
* Fix signature of Series.map()

* Add tests for Series.map() hints.
  • Loading branch information
JanEricNitschke authored Jun 27, 2024
1 parent 56ddb6f commit 54a763c
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 1 deletion.
19 changes: 19 additions & 0 deletions pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,25 @@ S1 = TypeVar(
| BaseOffset,
)

S2 = TypeVar(
"S2",
bound=str
| bytes
| datetime.date
| datetime.time
| bool
| int
| float
| complex
| Dtype
| datetime.datetime # includes pd.Timestamp
| datetime.timedelta # includes pd.Timedelta
| Period
| Interval
| CategoricalDtype
| BaseOffset,
)

IndexingInt: TypeAlias = (
int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8
)
Expand Down
14 changes: 13 additions & 1 deletion pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ from pandas._libs.tslibs import BaseOffset
from pandas._libs.tslibs.nattype import NaTType
from pandas._typing import (
S1,
S2,
AggFuncTypeBase,
AggFuncTypeDictFrame,
AggFuncTypeSeriesToFrame,
Expand Down Expand Up @@ -913,7 +914,18 @@ class Series(IndexOpsMixin[S1], NDFrame):
level: Level = ...,
fill_value: int | _str | dict | None = ...,
) -> DataFrame: ...
def map(self, arg, na_action: Literal["ignore"] | None = ...) -> Series[S1]: ...
@overload
def map(
self,
arg: Callable[[S1], S2 | NAType] | Mapping[S1, S2] | Series[S2],
na_action: Literal["ignore"] = ...,
) -> Series[S2]: ...
@overload
def map(
self,
arg: Callable[[S1 | NAType], S2 | NAType] | Mapping[S1, S2] | Series[S2],
na_action: None = ...,
) -> Series[S2]: ...
@overload
def aggregate( # type: ignore[overload-overlap]
self: Series[int],
Expand Down
44 changes: 44 additions & 0 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3230,3 +3230,47 @@ def test_operator_constistency() -> None:
pd.Series,
pd.Timedelta,
)


def test_map() -> None:
s = pd.Series([1, 2, 3])

mapping = {1: "a", 2: "b", 3: "c"}
check(
assert_type(s.map(mapping, na_action="ignore"), "pd.Series[str]"),
pd.Series,
str,
)

def callable(x: int) -> str:
return str(x)

check(
assert_type(s.map(callable, na_action="ignore"), "pd.Series[str]"),
pd.Series,
str,
)

series = pd.Series(["a", "b", "c"])
check(
assert_type(s.map(series, na_action="ignore"), "pd.Series[str]"), pd.Series, str
)


def test_map_na() -> None:
s: pd.Series[int] = pd.Series([1, pd.NA, 3])

mapping = {1: "a", 2: "b", 3: "c"}
check(assert_type(s.map(mapping, na_action=None), "pd.Series[str]"), pd.Series, str)

def callable(x: int | NAType) -> str | NAType:
if isinstance(x, int):
return str(x)
return x

check(
assert_type(s.map(callable, na_action=None), "pd.Series[str]"), pd.Series, str
)

series = pd.Series(["a", "b", "c"])
check(assert_type(s.map(series, na_action=None), "pd.Series[str]"), pd.Series, str)

0 comments on commit 54a763c

Please sign in to comment.