Skip to content

Commit 4704751

Browse files
authored
FEAT-modin-project#6934: Support 'include_groups=False' parameter in 'groupby.apply()' (modin-project#6938)
Signed-off-by: Dmitry Chigarev <dmitry.chigarev@intel.com>
1 parent b875991 commit 4704751

File tree

3 files changed, with 30 additions and 17 deletions.

modin/core/dataframe/pandas/dataframe/dataframe.py

+1 -6
Original file line number | Diff line number | Diff line change
@@ -3967,12 +3967,7 @@ def apply_func(df): # pragma: no cover
39673967
df = df.squeeze(axis=1)
39683968
result = operator(df.groupby(by, **kwargs))
39693969

3970-
if (
3971-
align_result_columns
3972-
and df.empty
3973-
and result.empty
3974-
and df.columns.equals(result.columns)
3975-
):
3970+
if align_result_columns and df.empty and result.empty:
39763971
# We want to align columns only of those frames that actually performed
39773972
# some groupby aggregation, if an empty frame was originally passed
39783973
# (an empty bin on reshuffling was created) then there were no groupby

modin/pandas/groupby.py

+1 -11
Original file line number | Diff line number | Diff line change
@@ -646,16 +646,6 @@ def cummax(self, axis=lib.no_default, numeric_only=False, **kwargs):
646646
)
647647

648648
def apply(self, func, *args, include_groups=True, **kwargs):
649-
if not include_groups:
650-
return self._default_to_pandas(
651-
lambda df: df.apply(
652-
func,
653-
*args,
654-
include_groups=include_groups,
655-
**kwargs,
656-
)
657-
)
658-
659649
func = cast_function_modin2pandas(func)
660650
if not isinstance(func, BuiltinFunctionType):
661651
func = wrap_udf_function(func)
@@ -665,7 +655,7 @@ def apply(self, func, *args, include_groups=True, **kwargs):
665655
numeric_only=False,
666656
agg_func=func,
667657
agg_args=args,
668-
agg_kwargs=kwargs,
658+
agg_kwargs={**kwargs, "include_groups": include_groups},
669659
how="group_wise",
670660
)
671661
reduced_index = pandas.Index([MODIN_UNNAMED_SERIES_LABEL])

modin/pandas/test/test_groupby.py

+28
Original file line number | Diff line number | Diff line change
@@ -3358,3 +3358,31 @@ def test_range_groupby_categories_external_grouper(columns, cat_cols):
33583358
pd_df, pd_by = get_external_groupers(pd_df, columns, drop_from_original_df=True)
33593359

33603360
eval_general(md_df.groupby(md_by), pd_df.groupby(pd_by), lambda grp: grp.count())
3361+
3362+
3363+
@pytest.mark.parametrize("by", [["a"], ["a", "b"]])
3364+
@pytest.mark.parametrize("as_index", [True, False])
3365+
@pytest.mark.parametrize("include_groups", [True, False])
3366+
def test_include_groups(by, as_index, include_groups):
3367+
data = {
3368+
"a": [1, 1, 2, 2] * 64,
3369+
"b": [11, 11, 22, 22] * 64,
3370+
"c": [111, 111, 222, 222] * 64,
3371+
"data": [1, 2, 3, 4] * 64,
3372+
}
3373+
3374+
def func(df):
3375+
if include_groups:
3376+
assert len(df.columns.intersection(by)) == len(by)
3377+
else:
3378+
assert len(df.columns.intersection(by)) == 0
3379+
return df.sum()
3380+
3381+
md_df, pd_df = create_test_dfs(data)
3382+
eval_general(
3383+
md_df,
3384+
pd_df,
3385+
lambda df: df.groupby(by, as_index=as_index).apply(
3386+
func, include_groups=include_groups
3387+
),
3388+
)

0 commit comments

Comments (0)