Skip to content

Commit 06818d6

Browse files
committed
BinGrouper: Support setting labels when provided with IntervalIndex
Removes a pandas limitation that we don't need.
1 parent 2f1751d commit 06818d6

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

xarray/groupers.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ class BinGrouper(Grouper):
319319
the resulting bins. If False, returns only integer indicators of the
320320
bins. This affects the type of the output container (see below).
321321
This argument is ignored when `bins` is an IntervalIndex. If True,
322-
raises an error. When `ordered=False`, labels must be provided.
322+
raises an error.
323323
retbins : bool, default False
324324
Whether to return the bins or not. Useful when bins is provided
325325
as a scalar.
@@ -394,17 +394,19 @@ def factorize(self, group: T_Group) -> EncodedGroups:
394394

395395
# This seems silly, but it lets us have Pandas handle the complexity
396396
# of `labels`, `precision`, and `include_lowest`, even when group is a chunked array
397-
dummy, _ = self._cut(np.array([0]).astype(group.dtype))
398-
full_index = dummy.categories
397+
if self.labels is None:
398+
dummy, _ = self._cut(np.array([0]).astype(group.dtype))
399+
full_index = dummy.categories
400+
else:
401+
full_index = pd.CategoricalIndex(self.labels)
402+
399403
if not by_is_chunked:
400404
uniques = np.sort(pd.unique(codes.data.ravel()))
401405
unique_values = full_index[uniques[uniques != -1]]
402406
else:
403407
unique_values = full_index
404408

405-
unique_coord = Variable(
406-
dims=new_dim_name, data=unique_values, attrs=group.attrs
407-
)
409+
unique_coord = Variable(dims=self.name, data=unique_values, attrs=group.attrs)
408410
return EncodedGroups(
409411
codes=codes,
410412
full_index=full_index,

xarray/tests/test_groupby.py

+16
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,22 @@ def test_groupby_bins_cut_kwargs(use_flox: bool) -> None:
10621062
).mean()
10631063
assert_identical(expected, actual)
10641064

1065+
with xr.set_options(use_flox=use_flox):
1066+
bins_index = pd.IntervalIndex.from_breaks(x_bins)
1067+
labels = ["one", "two", "three"]
1068+
actual = da.groupby(x=BinGrouper(bins=bins_index, labels=labels)).sum()
1069+
assert actual.xindexes["x_bins"].index.equals(pd.CategoricalIndex(labels))
1070+
1071+
1072+
def test_groupby_bins_name_kwarg() -> None:
1073+
da = xr.DataArray(np.arange(12).reshape(6, 2), dims=("x", "y"))
1074+
x_bins = (0, 2, 4, 6)
1075+
actual = da.groupby_bins("x", bins=x_bins, name="foo").sum()
1076+
assert "foo" in actual.dims
1077+
1078+
actual = da.groupby(x=BinGrouper(bins=x_bins, name="foo")).sum()
1079+
assert "foo" in actual.dims
1080+
10651081

10661082
@pytest.mark.parametrize("indexed_coord", [True, False])
10671083
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)