From 5480c5a52a030450f5eddd604de872490d0535c1 Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Thu, 15 Aug 2024 11:07:46 -0400 Subject: [PATCH] Enable direct lit --- py-polars/polars/functions/lit.py | 16 ++++++---------- py-polars/src/functions/lazy.rs | 5 ++++- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/py-polars/polars/functions/lit.py b/py-polars/polars/functions/lit.py index 5bf327257694..700c65dbee7c 100644 --- a/py-polars/polars/functions/lit.py +++ b/py-polars/polars/functions/lit.py @@ -7,8 +7,6 @@ import polars._reexport as pl from polars._utils.convert import ( - date_to_int, - datetime_to_int, time_to_int, timedelta_to_int, ) @@ -79,8 +77,7 @@ def lit( if isinstance(value, datetime): if dtype == Date: - dt_int = date_to_int(value.date()) - return lit(dt_int).cast(Date) + return wrap_expr(plr.lit(value.date(), allow_object=False)) # parse time unit if dtype is not None and (tu := getattr(dtype, "time_unit", "us")) is not None: @@ -109,8 +106,7 @@ def lit( raise TypeError(msg) dt_utc = value.replace(tzinfo=timezone.utc) - dt_int = datetime_to_int(dt_utc, time_unit) - expr = lit(dt_int).cast(Datetime(time_unit)) + expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast(Datetime(time_unit)) if tz is not None: expr = expr.dt.replace_time_zone( tz, ambiguous="earliest" if value.fold == 0 else "latest" @@ -134,14 +130,14 @@ def lit( if dtype == Datetime: time_unit = getattr(dtype, "time_unit", "us") or "us" dt_utc = datetime(value.year, value.month, value.day) - dt_int = datetime_to_int(dt_utc, time_unit) - expr = lit(dt_int).cast(Datetime(time_unit)) + expr = wrap_expr(plr.lit(dt_utc, allow_object=False)).cast( + Datetime(time_unit) + ) if (time_zone := getattr(dtype, "time_zone", None)) is not None: expr = expr.dt.replace_time_zone(str(time_zone)) return expr else: - date_int = date_to_int(value) - return lit(date_int).cast(Date) + return wrap_expr(plr.lit(value, allow_object=False)) elif isinstance(value, pl.Series): value = value._s diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index 49325e617170..aa098aee2cb0 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -435,7 +435,10 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool) -> PyResult { Ok(dsl::lit(Null {}).into()) } else if let Ok(value) = value.downcast::() { Ok(dsl::lit(value.as_bytes()).into()) - } else if value.get_type().qualname().unwrap() == "Decimal" { + } else if matches!( + value.get_type().qualname().unwrap().as_str(), + "date" | "datetime" | "Decimal" + ) { let av = py_object_to_any_value(value, true)?; Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into()) } else if allow_object {