Skip to content

Commit

Permalink
test(python): Improve hypothesis strategy for decimals (#16001)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored May 2, 2024
1 parent ebd8aec commit 5062732
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 46 deletions.
8 changes: 2 additions & 6 deletions py-polars/polars/testing/parametric/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
from polars.string_cache import StringCache
from polars.testing.parametric.strategies import (
_flexhash,
all_strategies,
between,
create_array_strategy,
create_list_strategy,
dtype_strategies,
scalar_strategies,
)

Expand Down Expand Up @@ -381,11 +381,7 @@ def draw_series(draw: DrawFn) -> Series:
if strategy is None:
if series_dtype is Datetime or series_dtype is Duration:
series_dtype = series_dtype(random.choice(_time_units)) # type: ignore[operator]
dtype_strategy = all_strategies[
series_dtype
if series_dtype in all_strategies
else series_dtype.base_type()
]
dtype_strategy = draw(dtype_strategies(series_dtype))
else:
dtype_strategy = strategy

Expand Down
89 changes: 49 additions & 40 deletions py-polars/polars/testing/parametric/strategies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from datetime import datetime, timedelta
from decimal import Decimal as PyDecimal
from itertools import chain
from random import choice, randint, shuffle
from string import ascii_uppercase
Expand All @@ -14,6 +15,7 @@
Sequence,
)

import hypothesis.strategies as st
from hypothesis.strategies import (
SearchStrategy,
binary,
Expand All @@ -22,7 +24,6 @@
composite,
dates,
datetimes,
decimals,
floats,
from_type,
integers,
Expand Down Expand Up @@ -56,13 +57,11 @@
UInt16,
UInt32,
UInt64,
is_polars_dtype,
)
from polars.type_aliases import PolarsDataType

if TYPE_CHECKING:
import sys
from decimal import Decimal as PyDecimal

from hypothesis.strategies import DrawFn

Expand All @@ -72,6 +71,26 @@
from typing_extensions import Self


@composite
def dtype_strategies(draw: DrawFn, dtype: PolarsDataType) -> SearchStrategy[Any]:
"""Returns a strategy which generates valid values for the given data type."""
if (strategy := all_strategies.get(dtype)) is not None:
return strategy
elif (strategy_base := all_strategies.get(dtype.base_type())) is not None:
return strategy_base

if dtype == Decimal:
return draw(
decimal_strategies(
precision=getattr(dtype, "precision", None),
scale=getattr(dtype, "scale", None),
)
)
else:
msg = f"unsupported data type: {dtype}"
raise TypeError(msg)


def between(draw: DrawFn, type_: type, min_: Any, max_: Any) -> Any:
"""Draw a value in a given range from a type-inferred strategy."""
strategy_init = from_type(type_).function # type: ignore[attr-defined]
Expand Down Expand Up @@ -117,19 +136,28 @@ def between(draw: DrawFn, type_: type, min_: Any, max_: Any) -> Any:


@composite
def strategy_decimal(draw: DrawFn) -> PyDecimal:
"""Draw a decimal value, varying the number of decimal places."""
places = draw(integers(min_value=0, max_value=18))
return draw(
# TODO: once fixed, re-enable decimal nan/inf values...
# (see https://github.com/pola-rs/polars/issues/8421)
decimals(
allow_nan=False,
allow_infinity=False,
min_value=-(2**66),
max_value=(2**66) - 1,
places=places,
)
def decimal_strategies(
draw: DrawFn, precision: int | None = None, scale: int | None = None
) -> SearchStrategy[PyDecimal]:
"""Returns a strategy which generates instances of Python `Decimal`."""
if precision is None:
precision = draw(integers(min_value=scale or 1, max_value=38))
if scale is None:
scale = draw(integers(min_value=0, max_value=precision))

exclusive_limit = PyDecimal(f"1E+{precision - scale}")
epsilon = PyDecimal(f"1E-{scale}")
limit = exclusive_limit - epsilon
if limit == exclusive_limit: # Limit cannot be set exactly due to precision issues
multiplier = PyDecimal("1") - PyDecimal("1E-20") # 0.999...
limit = limit * multiplier

return st.decimals(
allow_nan=False,
allow_infinity=False,
min_value=-limit,
max_value=limit,
places=scale,
)


Expand Down Expand Up @@ -272,34 +300,15 @@ def update(self, items: StrategyLookup) -> Self: # type: ignore[override]
Categorical: strategy_categorical,
String: strategy_string,
Binary: strategy_binary,
Decimal: strategy_decimal(),
}
)
nested_strategies: StrategyLookup = StrategyLookup()


def _get_strategy_dtypes(
*,
base_type: bool = False,
excluding: tuple[PolarsDataType] | PolarsDataType | None = None,
) -> list[PolarsDataType]:
"""
Get a list of all the dtypes for which we have a strategy.
Parameters
----------
base_type
If True, return the base types for each dtype (eg:`List(String)` → `List`).
excluding
A dtype or sequence of dtypes to omit from the results.
"""
excluding = (excluding,) if is_polars_dtype(excluding) else (excluding or ()) # type: ignore[assignment]
def _get_strategy_dtypes() -> list[PolarsDataType]:
"""Get a list of all the dtypes for which we have a strategy."""
strategy_dtypes = list(chain(scalar_strategies.keys(), nested_strategies.keys()))
return [
(tp.base_type() if base_type else tp)
for tp in strategy_dtypes
if tp not in excluding # type: ignore[operator]
]
return [tp.base_type() for tp in strategy_dtypes]


def _flexhash(elem: Any) -> int:
Expand Down Expand Up @@ -351,7 +360,7 @@ def create_array_strategy(
width = randint(a=1, b=8)

if inner_dtype is None:
strats = list(_get_strategy_dtypes(base_type=True))
strats = list(_get_strategy_dtypes())
shuffle(strats)
inner_dtype = choice(strats)

Expand Down Expand Up @@ -431,7 +440,7 @@ def create_list_strategy(
raise ValueError(msg)

if inner_dtype is None:
strats = list(_get_strategy_dtypes(base_type=True))
strats = list(_get_strategy_dtypes())
shuffle(strats)
inner_dtype = choice(strats)
if size:
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import gc
import os
import random
import string
import sys
Expand All @@ -11,6 +12,11 @@
import pytest

import polars as pl
from polars.testing.parametric.profiles import load_profile

load_profile(
profile=os.environ.get("POLARS_HYPOTHESIS_PROFILE", "fast"), # type: ignore[arg-type]
)


@pytest.fixture()
Expand Down

0 comments on commit 5062732

Please sign in to comment.