Skip to content

Commit

Permalink
fix(python): Respect dtype and strict in pl.Series's constructor for pyarrow arrays, numpy arrays, and pyarrow-backed pandas (#15962)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhengbo Wang authored May 9, 2024
1 parent 72185cf commit f18c306
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 9 deletions.
19 changes: 15 additions & 4 deletions py-polars/polars/_utils/construction/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,14 +402,15 @@ def pandas_to_pyseries(
values: pd.Series[Any] | pd.Index[Any] | pd.DatetimeIndex,
dtype: PolarsDataType | None = None,
*,
strict: bool = True,
nan_to_null: bool = True,
) -> PySeries:
"""Construct a PySeries from a pandas Series or DatetimeIndex."""
if not name and values.name is not None:
name = str(values.name)
if is_simple_numpy_backed_pandas_series(values):
return pl.Series(
name, values.to_numpy(), dtype=dtype, nan_to_null=nan_to_null
name, values.to_numpy(), dtype=dtype, nan_to_null=nan_to_null, strict=strict
)._s
if not _PYARROW_AVAILABLE:
msg = (
Expand All @@ -419,11 +420,21 @@ def pandas_to_pyseries(
)
raise ImportError(msg)
return arrow_to_pyseries(
name, plc.pandas_series_to_arrow(values, nan_to_null=nan_to_null)
name,
plc.pandas_series_to_arrow(values, nan_to_null=nan_to_null),
dtype=dtype,
strict=strict,
)


def arrow_to_pyseries(name: str, values: pa.Array, *, rechunk: bool = True) -> PySeries:
def arrow_to_pyseries(
name: str,
values: pa.Array,
dtype: PolarsDataType | None = None,
*,
strict: bool = True,
rechunk: bool = True,
) -> PySeries:
"""Construct a PySeries from an Arrow array."""
array = plc.coerce_arrow(values)

Expand Down Expand Up @@ -460,7 +471,7 @@ def arrow_to_pyseries(name: str, values: pa.Array, *, rechunk: bool = True) -> P
if rechunk:
pys.rechunk(in_place=True)

return pys
return pys.cast(dtype, strict=strict) if dtype is not None else pys


def numpy_to_pyseries(
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,17 +343,17 @@ def __init__(
return

if dtype is not None:
self._s = self.cast(dtype, strict=True)._s
self._s = self.cast(dtype, strict=strict)._s

elif _check_for_pyarrow(values) and isinstance(
values, (pa.Array, pa.ChunkedArray)
):
self._s = arrow_to_pyseries(name, values)
self._s = arrow_to_pyseries(name, values, dtype=dtype, strict=strict)

elif _check_for_pandas(values) and isinstance(
values, (pd.Series, pd.Index, pd.DatetimeIndex)
):
self._s = pandas_to_pyseries(name, values, dtype=dtype)
self._s = pandas_to_pyseries(name, values, dtype=dtype, strict=strict)

elif _is_generator(values):
self._s = iterable_to_pyseries(name, values, dtype=dtype, strict=strict)
Expand Down
43 changes: 41 additions & 2 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2340,5 +2340,44 @@ def test_search_sorted(


def test_series_from_pandas_with_dtype() -> None:
    """Check that `pl.Series` applies `dtype` and `strict` to pandas input.

    Covers a default-dtype pandas Series and pandas Series created with the
    nullable "Int16"/"Int8" dtypes. For each, a value that does not fit the
    requested unsigned dtype must raise `pl.ComputeError` by default and
    become null when `strict=False`.
    """
    # NOTE(review): this text comes from a rendered diff without +/- markers.
    # The file header reports 2 deletions for this test module, and the two
    # Float32 statements below appear to be those removed lines - confirm
    # against the real file before relying on them.
    s = pl.Series("foo", pd.Series([1, 2, 3]), pl.Float32)
    assert_series_equal(s, pl.Series("foo", [1, 2, 3], dtype=pl.Float32))
    # Values that fit the target dtype: both pandas inputs must yield Int8.
    expected = pl.Series("foo", [1, 2, 3], dtype=pl.Int8)
    s = pl.Series("foo", pd.Series([1, 2, 3]), pl.Int8)
    assert_series_equal(s, expected)
    s = pl.Series("foo", pd.Series([1, 2, 3], dtype="Int16"), pl.Int8)
    assert_series_equal(s, expected)

    # Default-dtype pandas Series: -1 cannot be represented as UInt8.
    with pytest.raises(pl.ComputeError, match="conversion from"):
        pl.Series("foo", pd.Series([-1, 2, 3]), pl.UInt8)
    s = pl.Series("foo", pd.Series([-1, 2, 3]), pl.UInt8, strict=False)
    assert s.to_list() == [None, 2, 3]
    assert s.dtype == pl.UInt8

    # Same checks for a pandas Series using the nullable "Int8" dtype.
    with pytest.raises(pl.ComputeError, match="conversion from"):
        pl.Series("foo", pd.Series([-1, 2, 3], dtype="Int8"), pl.UInt8)
    s = pl.Series("foo", pd.Series([-1, 2, 3], dtype="Int8"), pl.UInt8, strict=False)
    assert s.to_list() == [None, 2, 3]
    assert s.dtype == pl.UInt8


def test_series_from_pyarrow_with_dtype() -> None:
    """Check that `pl.Series` applies `dtype` and `strict` to pyarrow arrays."""
    # A dtype that can hold every value is applied as requested.
    signed = pl.Series("foo", pa.array([-1, 2, 3]), pl.Int8)
    assert_series_equal(signed, pl.Series("foo", [-1, 2, 3], dtype=pl.Int8))

    # Without `strict=False`, -1 cannot be converted to an unsigned dtype.
    with pytest.raises(pl.ComputeError, match="conversion from"):
        pl.Series("foo", pa.array([-1, 2, 3]), pl.UInt8)

    # With `strict=False`, the unrepresentable value becomes null instead.
    lenient = pl.Series("foo", pa.array([-1, 2, 3]), dtype=pl.UInt8, strict=False)
    assert lenient.to_list() == [None, 2, 3]
    assert lenient.dtype == pl.UInt8


def test_series_from_numpy_with_dtye() -> None:
    """Check that `pl.Series` applies `dtype` and `strict` to numpy arrays."""
    # NOTE(review): "dtye" in the function name looks like a typo for "dtype";
    # the name is left unchanged here so the test id stays the same.

    # A dtype that can hold every value is applied as requested.
    signed = pl.Series("foo", np.array([-1, 2, 3]), pl.Int8)
    assert_series_equal(signed, pl.Series("foo", [-1, 2, 3], dtype=pl.Int8))

    # Without `strict=False`, -1 cannot be converted to an unsigned dtype.
    with pytest.raises(pl.ComputeError, match="conversion from"):
        pl.Series("foo", np.array([-1, 2, 3]), pl.UInt8)

    # With `strict=False`, the unrepresentable value becomes null instead.
    lenient = pl.Series("foo", np.array([-1, 2, 3]), dtype=pl.UInt8, strict=False)
    assert lenient.to_list() == [None, 2, 3]
    assert lenient.dtype == pl.UInt8

0 comments on commit f18c306

Please sign in to comment.