diff --git a/py-polars/polars/datatypes/classes.py b/py-polars/polars/datatypes/classes.py index bc60c3d6bad1..4008775072ac 100644 --- a/py-polars/polars/datatypes/classes.py +++ b/py-polars/polars/datatypes/classes.py @@ -582,7 +582,7 @@ def __hash__(self) -> int: class Enum(DataType): """ - A fixed set categorical encoding of a set of strings. + A fixed categorical encoding of a unique set of strings. .. warning:: This functionality is considered **unstable**. @@ -592,8 +592,22 @@ class Enum(DataType): Parameters ---------- categories - The categories in the dataset. Categories must be strings. - """ + The categories in the dataset; must be a unique set of strings, or an + existing Python string-valued enum. + + Examples + -------- + Explicitly define enumeration categories: + + >>> pl.Enum(["north", "south", "east", "west"]) + Enum(categories=['north', 'south', 'east', 'west']) + + Initialise from an existing Python enumeration: + + >>> from http import HTTPMethod + >>> pl.Enum(HTTPMethod) + Enum(categories=['CONNECT', 'DELETE', 'GET', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT', 'TRACE']) + """ # noqa: W505 categories: Series @@ -608,7 +622,17 @@ def __init__(self, categories: Series | Iterable[str] | type[enum.Enum]) -> None ) if isclass(categories) and issubclass(categories, enum.Enum): - categories = pl.Series(values=categories.__members__.values()) + for enum_subclass in (enum.IntFlag, enum.Flag, enum.IntEnum): + if issubclass(categories, enum_subclass): + enum_type_name = enum_subclass.__name__ + msg = f"Enum categories must be strings; Python `enum.{enum_type_name}` values are integers" + raise TypeError(msg) + + enum_values = [ + (v if isinstance(v, str) else v.value) + for v in categories.__members__.values() + ] + categories = pl.Series(values=enum_values) elif not isinstance(categories, pl.Series): categories = pl.Series(values=categories) diff --git a/py-polars/tests/unit/datatypes/test_enum.py b/py-polars/tests/unit/datatypes/test_enum.py index bc5a9370a222..bd23ddcf623e 100644 --- a/py-polars/tests/unit/datatypes/test_enum.py +++ b/py-polars/tests/unit/datatypes/test_enum.py @@ -3,6 +3,7 @@ import enum import operator import re +import sys from datetime import date from textwrap import dedent from typing import Any, Callable @@ -42,32 +43,70 @@ def test_enum_init_empty(categories: pl.Series | list[str] | None) -> None: assert_series_equal(dtype.categories, expected) -def test_enum_init_python_enum_19724() -> None: - class PythonEnum(str, enum.Enum): - CAT1 = "A" - CAT2 = "B" - CAT3 = "C" +def test_enum_init_from_python() -> None: + # standard string enum + class Color1(str, enum.Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" - result = pl.Enum(PythonEnum) - assert result == pl.Enum(["A", "B", "C"]) + dtype = pl.Enum(Color1) + assert dtype == pl.Enum(["red", "green", "blue"]) + # standard generic enum + class Color2(enum.Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" -def test_enum_init_python_enum_ints_19724() -> None: - class PythonEnum(int, enum.Enum): - CAT1 = 1 - CAT2 = 2 - CAT3 = 3 + dtype = pl.Enum(Color2) + assert dtype == pl.Enum(["red", "green", "blue"]) - with pytest.raises(TypeError, match="Enum categories must be strings"): - pl.Enum(PythonEnum) + # specialised string enum + if sys.version_info >= (3, 11): + + class Color3(enum.Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + dtype = pl.Enum(Color3) + assert dtype == pl.Enum(["red", "green", "blue"]) + + +def test_enum_init_from_python_invalid() -> None: + class Color(int, enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + with pytest.raises( + TypeError, + match="Enum categories must be strings", + ): + pl.Enum(Color) + + # flag/int enums + for EnumBase in (enum.Flag, enum.IntFlag, enum.IntEnum): + + class Color(EnumBase): # type: ignore[no-redef,misc,valid-type] + RED = enum.auto() + GREEN = enum.auto() + BLUE = enum.auto() + + base_name = EnumBase.__name__ + + with pytest.raises( + TypeError, + match=f"Enum categories must be strings; Python `enum.{base_name}` values are integers", + ): + pl.Enum(Color) def test_enum_non_existent() -> None: with pytest.raises( InvalidOperationError, - match=re.escape( - "conversion from `str` to `enum` failed in column '' for 1 out of 4 values: [\"c\"]" - ), + match="conversion from `str` to `enum` failed in column '' for 1 out of 4 values: \\[\"c\"\\]", ): pl.Series([None, "a", "b", "c"], dtype=pl.Enum(categories=["a", "b"]))