Skip to content

Commit 04872a9

Browse files
committed
BinGrouper: Support setting labels when provided with IntervalIndex
Removes a pandas limitation that we don't need.
1 parent 2f1751d commit 04872a9

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

xarray/groupers.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ class BinGrouper(Grouper):
319319
the resulting bins. If False, returns only integer indicators of the
320320
bins. This affects the type of the output container (see below).
321321
This argument is ignored when `bins` is an IntervalIndex. If True,
322-
raises an error. When `ordered=False`, labels must be provided.
322+
raises an error.
323323
retbins : bool, default False
324324
Whether to return the bins or not. Useful when bins is provided
325325
as a scalar.
@@ -394,8 +394,13 @@ def factorize(self, group: T_Group) -> EncodedGroups:
394394

395395
# This seems silly, but it lets us have Pandas handle the complexity
396396
# of `labels`, `precision`, and `include_lowest`, even when group is a chunked array
397-
dummy, _ = self._cut(np.array([0]).astype(group.dtype))
398-
full_index = dummy.categories
397+
# Pandas ignores labels when IntervalIndex is passed
398+
if not isinstance(self.bins, pd.IntervalIndex):
399+
dummy, _ = self._cut(np.array([0]).astype(group.dtype))
400+
full_index = dummy.categories
401+
else:
402+
full_index = pd.Index(self.labels)
403+
399404
if not by_is_chunked:
400405
uniques = np.sort(pd.unique(codes.data.ravel()))
401406
unique_values = full_index[uniques[uniques != -1]]

xarray/tests/test_groupby.py

+6
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,12 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
10621062
).mean()
10631063
assert_identical(expected, actual)
10641064

1065+
with xr.set_options(use_flox=use_flox):
1066+
bins_index = pd.IntervalIndex.from_breaks(x_bins)
1067+
labels = ["one", "two", "three"]
1068+
actual = da.groupby(x=BinGrouper(bins=bins_index, labels=labels)).sum()
1069+
assert actual.xindexes["x_bins"].index.equals(pd.Index(labels))
1070+
10651071

10661072
@pytest.mark.parametrize("indexed_coord", [True, False])
10671073
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)