Skip to content

Commit

Permalink
fix(python): Address indexing edge-case with numpy arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Nov 21, 2024
1 parent f718909 commit 3f45cfb
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 42 deletions.
6 changes: 5 additions & 1 deletion py-polars/polars/_utils/getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,9 @@ def _select_columns(
raise TypeError(msg)

elif _check_for_numpy(key) and isinstance(key, np.ndarray):
if key.ndim != 1:
if key.ndim == 0:
key = np.atleast_1d(key)
elif key.ndim != 1:
msg = "multi-dimensional NumPy arrays not supported as index"
raise TypeError(msg)

Expand Down Expand Up @@ -397,6 +399,8 @@ def _convert_np_ndarray_to_indices(arr: np.ndarray[Any, Any], size: int) -> Seri
# - Signed numpy array indexes are converted to pl.UInt32 (polars) or
# pl.UInt64 (polars_u64_idx) after negative indexes are converted
# to absolute indexes.
if arr.ndim == 0:
arr = np.atleast_1d(arr)
if arr.ndim != 1:
msg = "only 1D NumPy arrays can be treated as indices"
raise TypeError(msg)
Expand Down
92 changes: 51 additions & 41 deletions py-polars/tests/unit/dataframe/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,46 +294,6 @@ def test_df_getitem() -> None:
# empty list with column selector drops rows but keeps columns
assert_frame_equal(df[empty, :], df[:0])

# numpy array: assumed to be row indices if integers, or columns if strings

# numpy array: positive idxs and empty idx
for np_dtype in (
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
):
assert_frame_equal(
df[np.array([1, 0, 3, 2, 3, 0], dtype=np_dtype)],
pl.DataFrame(
{"a": [2.0, 1.0, 4.0, 3.0, 4.0, 1.0], "b": [4, 3, 6, 5, 6, 3]}
),
)
assert df[np.array([], dtype=np_dtype)].columns == ["a", "b"]

# numpy array: positive and negative idxs.
for np_dtype in (np.int8, np.int16, np.int32, np.int64):
assert_frame_equal(
df[np.array([-1, 0, -3, -2, 3, -4], dtype=np_dtype)],
pl.DataFrame(
{"a": [4.0, 1.0, 2.0, 3.0, 4.0, 1.0], "b": [6, 3, 4, 5, 6, 3]}
),
)

# note that we cannot use floats (even if they could be casted to integer without
# loss)
with pytest.raises(TypeError):
_ = df[np.array([1.0])]

with pytest.raises(
TypeError, match="multi-dimensional NumPy arrays not supported as index"
):
df[np.array([[0], [1]])]

# sequences (lists or tuples; tuple only if length != 2)
# if strings or list of expressions, assumed to be column names
# if bools, assumed to be a row mask
Expand Down Expand Up @@ -392,7 +352,57 @@ def test_df_getitem() -> None:
df[:, [True, False, True]]


def test_df_getitem2() -> None:
def test_df_getitem_numpy() -> None:
    """Exercise ``DataFrame.__getitem__`` with NumPy array keys.

    Integer arrays are expected to select rows (including negative and
    zero-dimensional indices); float and multi-dimensional arrays are
    expected to raise ``TypeError``.
    """
    df = pl.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [3, 4, 5, 6]})

    signed_dtypes = (np.int8, np.int16, np.int32, np.int64)
    unsigned_dtypes = (np.uint8, np.uint16, np.uint32, np.uint64)

    # non-negative row indices (with repeats), plus an empty index array
    expected_positive = pl.DataFrame(
        {"a": [2.0, 1.0, 4.0, 3.0, 4.0, 1.0], "b": [4, 3, 6, 5, 6, 3]}
    )
    for dtype in (*signed_dtypes, *unsigned_dtypes):
        row_idx = np.array([1, 0, 3, 2, 3, 0], dtype=dtype)
        assert_frame_equal(df[row_idx], expected_positive)

        no_rows = np.array([], dtype=dtype)
        assert df[no_rows].columns == ["a", "b"]

    # mixed positive/negative row indices: only meaningful for signed dtypes
    expected_mixed = pl.DataFrame(
        {"a": [4.0, 1.0, 2.0, 3.0, 4.0, 1.0], "b": [6, 3, 4, 5, 6, 3]}
    )
    for dtype in signed_dtypes:
        row_idx = np.array([-1, 0, -3, -2, 3, -4], dtype=dtype)
        assert_frame_equal(df[row_idx], expected_mixed)

    # a zero-dimensional array behaves like selecting a single row by int
    zero_dim_cases = (
        (0, {"a": [1.0], "b": [3]}),
        (1, {"a": [2.0], "b": [4]}),
    )
    for scalar, expected_data in zero_dim_cases:
        assert_frame_equal(df[np.array(scalar)], pl.DataFrame(expected_data))

    # float arrays are rejected, even when the values are integral
    float_idx = np.array([1.0])
    with pytest.raises(
        TypeError,
        match="cannot select columns using NumPy array of type float",
    ):
        _ = df[float_idx]

    # arrays with more than one dimension are rejected
    two_dim_idx = np.array([[0], [1]])
    with pytest.raises(
        TypeError,
        match="multi-dimensional NumPy arrays not supported as index",
    ):
        df[two_dim_idx]


def test_df_getitem_extended() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0], "c": ["a", "b", "c"]})

# select columns by mask
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/interop/numpy/test_to_numpy_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from numpy.testing import assert_array_equal

import polars as pl
from polars.testing import assert_series_equal
from polars.testing.parametric import series

if TYPE_CHECKING:
Expand Down Expand Up @@ -136,6 +137,19 @@ def test_series_to_numpy_date() -> None:
assert_allow_copy_false_raises(s)


def test_series_to_numpy_multi_dimensional_init() -> None:
    """Check ``pl.Series`` construction from 3D and zero-dimensional arrays."""
    # a 3D array should produce a single-row Array(Float64, shape=(3, 1)) series
    arr_3d = np.atleast_3d(np.array([-10.5, 0.0, 10.5]))
    expected_nested = pl.Series(
        [[[-10.5], [0.0], [10.5]]],
        dtype=pl.Array(pl.Float64, shape=(3, 1)),
    )
    assert_series_equal(pl.Series(arr_3d), expected_nested)

    # a zero-dimensional array should produce a one-element series
    scalar_arr = np.array(0)
    from_scalar = pl.Series(scalar_arr, dtype=pl.Int32)
    assert_series_equal(from_scalar, pl.Series([0], dtype=pl.Int32))


@pytest.mark.parametrize(
("dtype", "expected_dtype"),
[
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/series/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ def test_series_getitem_multiple_indices(indices: Any) -> None:
assert_series_equal(result, expected)


def test_series_getitem_numpy() -> None:
    """Check ``Series.__getitem__`` with 1D and zero-dimensional NumPy keys."""
    series = pl.Series([9, 8, 7])

    # (index array, expected values): positive, negative, and 0-dim indices
    cases = (
        (np.array([0, 2]), [9, 7]),
        (np.array([-1, -3]), [7, 9]),
        (np.array(-2), [8]),
    )
    for key, expected in cases:
        assert series[key].to_list() == expected


@pytest.mark.parametrize(
("input", "match"),
[
Expand Down

0 comments on commit 3f45cfb

Please sign in to comment.