Skip to content

Commit

Permalink
Allow Python Enum input in Polars Enum constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Nov 22, 2024
1 parent d424674 commit 99cbb42
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 3 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/datatypes/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def parse_py_type_into_dtype(input: PythonDataType | type[object]) -> PolarsData
elif input is list or input is tuple:
return List
elif isclass(input) and issubclass(input, enum.Enum):
return Enum(input.__members__.values())
return Enum(input)
# this is required as pass through. Don't remove
elif input == Unknown:
return Unknown
Expand Down
7 changes: 5 additions & 2 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import enum
from collections import OrderedDict
from collections.abc import Mapping
from datetime import timezone
Expand Down Expand Up @@ -596,7 +597,7 @@ class Enum(DataType):

categories: Series

def __init__(self, categories: Series | Iterable[str]) -> None:
def __init__(self, categories: Series | Iterable[str] | type[enum.Enum]) -> None:
# Issuing the warning on `__init__` does not trigger when the class is used
# without being instantiated, but it's better than nothing
from polars._utils.unstable import issue_unstable_warning
Expand All @@ -606,7 +607,9 @@ def __init__(self, categories: Series | Iterable[str]) -> None:
" It is a work-in-progress feature and may not always work as expected."
)

if not isinstance(categories, pl.Series):
if isclass(categories) and issubclass(categories, enum.Enum):
categories = pl.Series(values=categories.__members__.values())
elif not isinstance(categories, pl.Series):
categories = pl.Series(values=categories)

if categories.is_empty():
Expand Down
33 changes: 33 additions & 0 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import enum
import operator
import re
import sys
from datetime import date
from textwrap import dedent
from typing import Any, Callable
Expand Down Expand Up @@ -41,6 +43,37 @@ def test_enum_init_empty(categories: pl.Series | list[str] | None) -> None:
assert_series_equal(dtype.categories, expected)


def test_enum_init_python_enum_19724() -> None:
class PythonEnum(str, enum.Enum):
CAT1 = "A"
CAT2 = "B"
CAT3 = "C"

result = pl.Enum(PythonEnum)
assert result == pl.Enum(["A", "B", "C"])


def test_enum_init_python_enum_ints_19724() -> None:
class PythonEnum(int, enum.Enum):
CAT1 = 1
CAT2 = 2
CAT3 = 3

with pytest.raises(TypeError, match="Enum categories must be strings"):
pl.Enum(PythonEnum)


@pytest.mark.skipif(sys.version_info < (3, 11), reason="Requires Python 3.11 or later")
def test_enum_init_python_strenum_19724() -> None:
class PythonEnum(enum.StrEnum):
CAT1 = "A"
CAT2 = "B"
CAT3 = "C"

result = pl.Enum(PythonEnum)
assert result == pl.Enum(["A", "B", "C"])


def test_enum_non_existent() -> None:
with pytest.raises(
InvalidOperationError,
Expand Down

0 comments on commit 99cbb42

Please sign in to comment.