Skip to content

Commit

Permalink
feat(python): Infer Enum dtype on DataFrame cols constructed from Python Enums
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Dec 6, 2024
1 parent dc54699 commit ed030a3
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 20 deletions.
9 changes: 9 additions & 0 deletions py-polars/polars/_utils/construction/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import contextlib
from collections.abc import Generator, Iterator
from datetime import date, datetime, time, timedelta
from enum import Enum as PyEnum
from itertools import islice
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -124,6 +125,14 @@ def sequence_to_pyseries(
) and not isinstance(value, int):
python_dtype = dtype_to_py_type(dtype) # type: ignore[arg-type]

# if values are enums, infer and load the appropriate dtype/values
if issubclass(type(value), PyEnum):
if dtype is None and python_dtype is None:
with contextlib.suppress(TypeError):
dtype = Enum(type(value))
if not isinstance(value, (str, int)):
values = [v.value for v in values]

# physical branch
# flat data
if (
Expand Down
3 changes: 1 addition & 2 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,7 @@ def __init__(self, categories: Series | Iterable[str] | type[enum.Enum]) -> None
raise TypeError(msg)

enum_values = [
(v if isinstance(v, str) else v.value)
for v in categories.__members__.values()
getattr(v, "value", v) for v in categories.__members__.values()
]
categories = pl.Series(values=enum_values)
elif not isinstance(categories, pl.Series):
Expand Down
7 changes: 2 additions & 5 deletions py-polars/polars/functions/lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import polars._reexport as pl
from polars._utils.wrap import wrap_expr
from polars.datatypes import Date, Datetime, Duration, Enum
from polars.datatypes import Date, Datetime, Duration
from polars.dependencies import _check_for_numpy
from polars.dependencies import numpy as np

Expand Down Expand Up @@ -150,10 +150,7 @@ def lit(
)

elif isinstance(value, enum.Enum):
lit_value = value.value
if dtype is None and isinstance(value, str):
dtype = Enum(m.value for m in type(value))
return lit(lit_value, dtype=dtype)
return lit(value.value, dtype=dtype)

if dtype:
return wrap_expr(plr.lit(value, allow_object, is_scalar=True)).cast(dtype)
Expand Down
69 changes: 69 additions & 0 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,72 @@ def test_enum_19269() -> None:

assert out.to_dict(as_series=False) == {"a": ["X", "Y"], "b": ["X", "Z"]}
assert out.dtypes == [en, en]


@pytest.mark.parametrize(
    "EnumBase",
    [
        (enum.Enum,),
        # StrEnum is only referenced when running on Python 3.11 or later
        *([(enum.StrEnum,)] if sys.version_info >= (3, 11) else []),
        (str, enum.Enum),
    ],
)
def test_init_frame_from_enums(EnumBase: tuple[type, ...]) -> None:
    """Frame columns built from string-valued Python enums infer an Enum dtype.

    Each parametrized case supplies the base classes used to declare the
    local enum, so the same checks run against every supported enum flavour.
    """

    class Portfolio(*EnumBase):  # type: ignore[misc]
        TECH = "Technology"
        RETAIL = "Retail"
        OTHER = "Other"

    # the enum dtype should be inferred, with categories in declaration order
    inferred = pl.DataFrame(
        {"trade_id": [123, 456], "portfolio": [Portfolio.OTHER, Portfolio.TECH]}
    )
    expected_schema = {
        "trade_id": pl.Int64,
        "portfolio": pl.Enum(["Technology", "Retail", "Other"]),
    }
    expected = pl.DataFrame(
        {"trade_id": [123, 456], "portfolio": ["Other", "Technology"]},
        schema=expected_schema,
    )
    assert_frame_equal(expected, inferred)

    # an explicit String override must win over enum inference
    members = [Portfolio.OTHER, Portfolio.TECH, Portfolio.RETAIL]
    overridden = pl.DataFrame(
        {"trade_id": [123, 456, 789], "portfolio": members},
        schema_overrides={"portfolio": pl.String},
    )
    assert overridden.schema == {"trade_id": pl.Int64, "portfolio": pl.String}


@pytest.mark.parametrize(
    "EnumBase",
    [
        (enum.Enum,),
        (enum.Flag,),
        (enum.IntEnum,),
        (enum.IntFlag,),
        (int, enum.Enum),
    ],
)
def test_init_series_from_int_enum(EnumBase: tuple[type, ...]) -> None:
    """Integer-valued Python enums load into a Series as their raw values.

    Integer enums are not supported as polars enums, but constructing a
    Series from their members should still yield the underlying integers.
    """

    class Number(*EnumBase):  # type: ignore[misc]
        ONE = 1
        TWO = 2
        FOUR = 4
        EIGHT = 8

    members = [Number.EIGHT, Number.TWO, Number.FOUR]
    loaded = pl.Series(values=members)

    # the member values come through as a plain Int64 series
    as_integers = pl.Series(values=[8, 2, 4], dtype=pl.Int64)
    assert_series_equal(as_integers, loaded)
62 changes: 49 additions & 13 deletions py-polars/tests/unit/functions/test_lit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import enum
import sys
from datetime import datetime, timedelta
from decimal import Decimal
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -105,32 +106,67 @@ def test_lit_unsupported_type() -> None:
pl.lit(pl.LazyFrame({"a": [1, 2, 3]}))


def test_lit_enum_input_16668() -> None:
@pytest.mark.parametrize(
"EnumBase",
[
(enum.Enum,),
(enum.StrEnum,),
(str, enum.Enum),
]
if sys.version_info >= (3, 11)
else [
(enum.Enum,),
(str, enum.Enum),
],
)
def test_lit_enum_input_16668(EnumBase: tuple[type, ...]) -> None:
# https://github.com/pola-rs/polars/issues/16668

class State(str, enum.Enum):
VIC = "victoria"
NSW = "new south wales"
class State(*EnumBase): # type: ignore[misc]
NSW = "New South Wales"
QLD = "Queensland"
VIC = "Victoria"

# validate that frame schema has inferred the enum
df = pl.DataFrame({"state": [State.NSW, State.VIC]})
assert df.schema == {
"state": pl.Enum(["New South Wales", "Queensland", "Victoria"])
}

# check use of enum as lit/constraint
value = State.VIC
expected = "Victoria"

result = pl.lit(value)
assert pl.select(result).dtypes[0] == pl.Enum(["victoria", "new south wales"])
assert pl.select(result).item() == "victoria"
for lit_value in (
pl.lit(value),
pl.lit(value.value), # type: ignore[attr-defined]
):
assert pl.select(lit_value).item() == expected
assert df.filter(state=value).item() == expected
assert df.filter(state=lit_value).item() == expected

result = pl.lit(value, dtype=pl.String)
assert pl.select(result).dtypes[0] == pl.String
assert pl.select(result).item() == "victoria"
assert df.filter(pl.col("state") == State.QLD).is_empty()
assert df.filter(pl.col("state") != State.QLD).height == 2


def test_lit_enum_input_non_string() -> None:
@pytest.mark.parametrize(
"EnumBase",
[
(enum.Enum,),
(enum.Flag,),
(enum.IntEnum,),
(enum.IntFlag,),
(int, enum.Enum),
],
)
def test_lit_enum_input_non_string(EnumBase: tuple[type, ...]) -> None:
# https://github.com/pola-rs/polars/issues/16668

class State(int, enum.Enum):
class Number(*EnumBase): # type: ignore[misc]
ONE = 1
TWO = 2

value = State.ONE
value = Number.ONE

result = pl.lit(value)
assert pl.select(result).dtypes[0] == pl.Int32
Expand Down

0 comments on commit ed030a3

Please sign in to comment.