diff --git a/py-polars/polars/testing/parametric/primitives.py b/py-polars/polars/testing/parametric/primitives.py index bf5c30e19c88..fc41af3e8b7a 100644 --- a/py-polars/polars/testing/parametric/primitives.py +++ b/py-polars/polars/testing/parametric/primitives.py @@ -28,10 +28,10 @@ from polars.string_cache import StringCache from polars.testing.parametric.strategies import ( _flexhash, - all_strategies, between, create_array_strategy, create_list_strategy, + dtype_strategies, scalar_strategies, ) @@ -381,11 +381,7 @@ def draw_series(draw: DrawFn) -> Series: if strategy is None: if series_dtype is Datetime or series_dtype is Duration: series_dtype = series_dtype(random.choice(_time_units)) # type: ignore[operator] - dtype_strategy = all_strategies[ - series_dtype - if series_dtype in all_strategies - else series_dtype.base_type() - ] + dtype_strategy = draw(dtype_strategies(series_dtype)) else: dtype_strategy = strategy diff --git a/py-polars/polars/testing/parametric/strategies.py b/py-polars/polars/testing/parametric/strategies.py index 7e03e3808e36..2cfc3626c478 100644 --- a/py-polars/polars/testing/parametric/strategies.py +++ b/py-polars/polars/testing/parametric/strategies.py @@ -1,6 +1,7 @@ from __future__ import annotations from datetime import datetime, timedelta +from decimal import Decimal as PyDecimal from itertools import chain from random import choice, randint, shuffle from string import ascii_uppercase @@ -14,6 +15,7 @@ Sequence, ) +import hypothesis.strategies as st from hypothesis.strategies import ( SearchStrategy, binary, @@ -22,7 +24,6 @@ composite, dates, datetimes, - decimals, floats, from_type, integers, @@ -56,13 +57,11 @@ UInt16, UInt32, UInt64, - is_polars_dtype, ) from polars.type_aliases import PolarsDataType if TYPE_CHECKING: import sys - from decimal import Decimal as PyDecimal from hypothesis.strategies import DrawFn @@ -72,6 +71,26 @@ from typing_extensions import Self +@composite +def dtype_strategies(draw: DrawFn, dtype: PolarsDataType) -> SearchStrategy[Any]: + """Returns a strategy which generates valid values for the given data type.""" + if (strategy := all_strategies.get(dtype)) is not None: + return strategy + elif (strategy_base := all_strategies.get(dtype.base_type())) is not None: + return strategy_base + + if dtype == Decimal: + return draw( + decimal_strategies( + precision=getattr(dtype, "precision", None), + scale=getattr(dtype, "scale", None), + ) + ) + else: + msg = f"unsupported data type: {dtype}" + raise TypeError(msg) + + def between(draw: DrawFn, type_: type, min_: Any, max_: Any) -> Any: """Draw a value in a given range from a type-inferred strategy.""" strategy_init = from_type(type_).function # type: ignore[attr-defined] @@ -117,19 +136,28 @@ def between(draw: DrawFn, type_: type, min_: Any, max_: Any) -> Any: @composite -def strategy_decimal(draw: DrawFn) -> PyDecimal: - """Draw a decimal value, varying the number of decimal places.""" - places = draw(integers(min_value=0, max_value=18)) - return draw( - # TODO: once fixed, re-enable decimal nan/inf values... - # (see https://github.com/pola-rs/polars/issues/8421) - decimals( - allow_nan=False, - allow_infinity=False, - min_value=-(2**66), - max_value=(2**66) - 1, - places=places, - ) +def decimal_strategies( + draw: DrawFn, precision: int | None = None, scale: int | None = None +) -> SearchStrategy[PyDecimal]: + """Returns a strategy which generates instances of Python `Decimal`.""" + if precision is None: + precision = draw(integers(min_value=scale or 1, max_value=38)) + if scale is None: + scale = draw(integers(min_value=0, max_value=precision)) + + exclusive_limit = PyDecimal(f"1E+{precision - scale}") + epsilon = PyDecimal(f"1E-{scale}") + limit = exclusive_limit - epsilon + if limit == exclusive_limit: # Limit cannot be set exactly due to precision issues + multiplier = PyDecimal("1") - PyDecimal("1E-20") # 0.999... + limit = limit * multiplier + + return st.decimals( + allow_nan=False, + allow_infinity=False, + min_value=-limit, + max_value=limit, + places=scale, ) @@ -272,34 +300,15 @@ def update(self, items: StrategyLookup) -> Self: # type: ignore[override] Categorical: strategy_categorical, String: strategy_string, Binary: strategy_binary, - Decimal: strategy_decimal(), } ) nested_strategies: StrategyLookup = StrategyLookup() -def _get_strategy_dtypes( - *, - base_type: bool = False, - excluding: tuple[PolarsDataType] | PolarsDataType | None = None, -) -> list[PolarsDataType]: - """ - Get a list of all the dtypes for which we have a strategy. - - Parameters - ---------- - base_type - If True, return the base types for each dtype (eg:`List(String)` → `List`). - excluding - A dtype or sequence of dtypes to omit from the results. - """ - excluding = (excluding,) if is_polars_dtype(excluding) else (excluding or ()) # type: ignore[assignment] +def _get_strategy_dtypes() -> list[PolarsDataType]: + """Get a list of all the dtypes for which we have a strategy.""" strategy_dtypes = list(chain(scalar_strategies.keys(), nested_strategies.keys())) - return [ - (tp.base_type() if base_type else tp) - for tp in strategy_dtypes - if tp not in excluding # type: ignore[operator] - ] + return [tp.base_type() for tp in strategy_dtypes] def _flexhash(elem: Any) -> int: @@ -351,7 +360,7 @@ def create_array_strategy( width = randint(a=1, b=8) if inner_dtype is None: - strats = list(_get_strategy_dtypes(base_type=True)) + strats = list(_get_strategy_dtypes()) shuffle(strats) inner_dtype = choice(strats) @@ -431,7 +440,7 @@ def create_list_strategy( raise ValueError(msg) if inner_dtype is None: - strats = list(_get_strategy_dtypes(base_type=True)) + strats = list(_get_strategy_dtypes()) shuffle(strats) inner_dtype = choice(strats) if size: diff --git a/py-polars/tests/unit/conftest.py b/py-polars/tests/unit/conftest.py index a335d6c88447..cc4fe80fc1e1 100644 --- a/py-polars/tests/unit/conftest.py +++ b/py-polars/tests/unit/conftest.py @@ -1,6 +1,7 @@ from __future__ import annotations import gc +import os import random import string import sys @@ -11,6 +12,11 @@ import pytest import polars as pl +from polars.testing.parametric.profiles import load_profile + +load_profile( + profile=os.environ.get("POLARS_HYPOTHESIS_PROFILE", "fast"), # type: ignore[arg-type] +) @pytest.fixture()