Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Allow Python Enums as dtype inputs #19926

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions py-polars/polars/datatypes/_parse.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import enum
import functools
import re
import sys
from datetime import date, datetime, time, timedelta
from decimal import Decimal as PyDecimal
from inspect import isclass
from typing import TYPE_CHECKING, Any, ForwardRef, NoReturn, Union, get_args

from polars.datatypes.classes import (
Expand All @@ -14,6 +16,7 @@
Datetime,
Decimal,
Duration,
Enum,
Float64,
Int64,
List,
Expand Down Expand Up @@ -94,6 +97,8 @@ def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsData
return Null()
elif input is list or input is tuple:
return List
elif isclass(input) and issubclass(input, enum.Enum):
return Enum(input)
# this is required as pass through. Don't remove
elif input == Unknown:
return Unknown
Expand Down
7 changes: 5 additions & 2 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import enum
from collections import OrderedDict
from collections.abc import Mapping
from datetime import timezone
Expand Down Expand Up @@ -596,7 +597,7 @@ class Enum(DataType):

categories: Series

def __init__(self, categories: Series | Iterable[str]) -> None:
def __init__(self, categories: Series | Iterable[str] | type[enum.Enum]) -> None:
# Issuing the warning on `__init__` does not trigger when the class is used
# without being instantiated, but it's better than nothing
from polars._utils.unstable import issue_unstable_warning
Expand All @@ -606,7 +607,9 @@ def __init__(self, categories: Series | Iterable[str]) -> None:
" It is a work-in-progress feature and may not always work as expected."
)

if not isinstance(categories, pl.Series):
if isclass(categories) and issubclass(categories, enum.Enum):
categories = pl.Series(values=categories.__members__.values())
elif not isinstance(categories, pl.Series):
categories = pl.Series(values=categories)

if categories.is_empty():
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/constructors/test_dataframe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import sys
from collections import OrderedDict
from collections.abc import Mapping
Expand Down Expand Up @@ -194,3 +195,13 @@ def test_df_init_schema_object() -> None:
def test_df_init_data_orientation_inference_warning() -> None:
    """Check that `from_records` warns when it is given nested lists and a schema.

    No `orient` argument is passed, and the call is expected to emit a
    `DataOrientationWarning`.
    """
    records = [[1, 2, 3], [4, 5, 6]]
    column_names = ["a", "b", "c"]
    expected_warning = pytest.warns(DataOrientationWarning)
    with expected_warning:
        pl.from_records(records, schema=column_names)


def test_df_init_enum_dtype() -> None:
    """Check that a Python enum class can be used as a column dtype in a schema.

    The resulting column dtype must equal a `pl.Enum` built from the enum's
    string values.
    """

    class PythonEnum(str, enum.Enum):
        A = "A"
        B = "B"
        C = "C"

    column = "Col 1"
    frame = pl.DataFrame({column: ["A", "B", "C"]}, schema={column: PythonEnum})

    actual_dtype = frame.dtypes[0]
    expected_dtype = pl.Enum(["A", "B", "C"])
    assert actual_dtype == expected_dtype
21 changes: 21 additions & 0 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import operator
import re
from datetime import date
Expand Down Expand Up @@ -41,6 +42,26 @@ def test_enum_init_empty(categories: pl.Series | list[str] | None) -> None:
assert_series_equal(dtype.categories, expected)


def test_enum_init_python_enum_19724() -> None:
    """Check that `pl.Enum` accepts a string-valued Python enum class.

    The categories come from the member values ("A", "B", "C"), not from the
    member names (CAT1, CAT2, CAT3).
    """

    class PythonEnum(str, enum.Enum):
        CAT1 = "A"
        CAT2 = "B"
        CAT3 = "C"

    from_python_enum = pl.Enum(PythonEnum)
    from_string_list = pl.Enum(["A", "B", "C"])
    assert from_python_enum == from_string_list


def test_enum_init_python_enum_ints_19724() -> None:
    """Check that `pl.Enum` rejects a Python enum class with integer values."""

    class PythonEnum(int, enum.Enum):
        CAT1 = 1
        CAT2 = 2
        CAT3 = 3

    # Integer values cannot be categories, so construction is expected to fail.
    rejects_non_string = pytest.raises(
        TypeError, match="Enum categories must be strings"
    )
    with rejects_non_string:
        pl.Enum(PythonEnum)


def test_enum_non_existent() -> None:
with pytest.raises(
InvalidOperationError,
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/datatypes/test_parse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
from datetime import date, datetime
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -43,6 +44,27 @@ def test_parse_into_dtype(input: Any, expected: PolarsDataType) -> None:
assert_dtype_equal(result, expected)


def test_parse_into_dtype_enum_19724() -> None:
    """Check that `parse_into_dtype` turns a string-valued Python enum into `pl.Enum`.

    The parsed dtype must equal a `pl.Enum` whose categories are the enum's
    member values.
    """

    class PythonEnum(str, enum.Enum):
        CAT1 = "A"
        CAT2 = "B"
        CAT3 = "C"

    assert_dtype_equal(
        parse_into_dtype(PythonEnum),
        pl.Enum(["A", "B", "C"]),
    )


def test_parse_into_dtype_enum_ints_19724() -> None:
    """Check that `parse_into_dtype` raises for a Python enum with integer values."""

    class PythonEnum(int, enum.Enum):
        CAT1 = 1
        CAT2 = 2
        CAT3 = 3

    # Parsing is expected to fail with the "must be strings" TypeError.
    rejects_non_string = pytest.raises(
        TypeError, match="Enum categories must be strings"
    )
    with rejects_non_string:
        parse_into_dtype(PythonEnum)


@pytest.mark.parametrize(
("input", "expected"),
[
Expand Down