Skip to content

Commit

Permalink
Remove ddof parameter for pl.corr in Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
flowlight0 committed Dec 6, 2024
1 parent 84a68a7 commit b880194
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 33 deletions.
3 changes: 1 addition & 2 deletions crates/polars-compute/src/var_cov.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ impl PearsonState {
self.mean_y = new_mean_y;
}

pub fn finalize(&self, _ddof: u8) -> f64 {
// The division by sample_weight - ddof on both sides cancels out.
pub fn finalize(&self) -> f64 {
let denom = (self.dp_xx * self.dp_yy).sqrt();
if denom == 0.0 {
f64::NAN
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/tests/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn test_pearson_corr() -> PolarsResult<()> {
.lazy()
.group_by_stable([col("uid")])
// a double aggregation expression.
.agg([pearson_corr(col("day"), col("cumcases"), 1).alias("pearson_corr")])
.agg([pearson_corr(col("day"), col("cumcases")).alias("pearson_corr")])
.collect()?;
let s = out.column("pearson_corr")?.f64()?;
assert!((s.get(0).unwrap() - 0.997176).abs() < 0.000001);
Expand All @@ -25,7 +25,7 @@ fn test_pearson_corr() -> PolarsResult<()> {
.lazy()
.group_by_stable([col("uid")])
// a double aggregation expression.
.agg([pearson_corr(col("day"), col("cumcases"), 1)
.agg([pearson_corr(col("day"), col("cumcases"))
.pow(2.0)
.alias("pearson_corr")])
.collect()
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-ops/src/chunked_array/cov.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ where
}

/// Compute the pearson correlation between two columns.
pub fn pearson_corr<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>, ddof: u8) -> Option<f64>
pub fn pearson_corr<T>(a: &ChunkedArray<T>, b: &ChunkedArray<T>) -> Option<f64>
where
T: PolarsNumericType,
T::Native: AsPrimitive<f64>,
Expand All @@ -30,5 +30,5 @@ where
for (a, b) in a.downcast_iter().zip(b.downcast_iter()) {
out.combine(&polars_compute::var_cov::pearson_corr(a, b))
}
Some(out.finalize(ddof))
Some(out.finalize())
}
24 changes: 11 additions & 13 deletions crates/polars-plan/src/dsl/function_expr/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,9 @@ impl Display for CorrelationMethod {

pub(super) fn corr(s: &[Column], ddof: u8, method: CorrelationMethod) -> PolarsResult<Column> {
match method {
CorrelationMethod::Pearson => pearson_corr(s, ddof),
CorrelationMethod::Pearson => pearson_corr(s),
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
CorrelationMethod::SpearmanRank(propagate_nans) => {
spearman_rank_corr(s, ddof, propagate_nans)
},
CorrelationMethod::SpearmanRank(propagate_nans) => spearman_rank_corr(s, propagate_nans),
CorrelationMethod::Covariance => covariance(s, ddof),
}
}
Expand Down Expand Up @@ -61,32 +59,32 @@ fn covariance(s: &[Column], ddof: u8) -> PolarsResult<Column> {
Ok(Column::new(name, &[ret]))
}

fn pearson_corr(s: &[Column], ddof: u8) -> PolarsResult<Column> {
fn pearson_corr(s: &[Column]) -> PolarsResult<Column> {
let a = &s[0];
let b = &s[1];
let name = PlSmallStr::from_static("pearson_corr");

use polars_ops::chunked_array::cov::pearson_corr;
let ret = match a.dtype() {
DataType::Float32 => {
let ret = pearson_corr(a.f32().unwrap(), b.f32().unwrap(), ddof).map(|v| v as f32);
let ret = pearson_corr(a.f32().unwrap(), b.f32().unwrap()).map(|v| v as f32);
return Ok(Column::new(name.clone(), &[ret]));
},
DataType::Float64 => pearson_corr(a.f64().unwrap(), b.f64().unwrap(), ddof),
DataType::Int32 => pearson_corr(a.i32().unwrap(), b.i32().unwrap(), ddof),
DataType::Int64 => pearson_corr(a.i64().unwrap(), b.i64().unwrap(), ddof),
DataType::UInt32 => pearson_corr(a.u32().unwrap(), b.u32().unwrap(), ddof),
DataType::Float64 => pearson_corr(a.f64().unwrap(), b.f64().unwrap()),
DataType::Int32 => pearson_corr(a.i32().unwrap(), b.i32().unwrap()),
DataType::Int64 => pearson_corr(a.i64().unwrap(), b.i64().unwrap()),
DataType::UInt32 => pearson_corr(a.u32().unwrap(), b.u32().unwrap()),
_ => {
let a = a.cast(&DataType::Float64)?;
let b = b.cast(&DataType::Float64)?;
pearson_corr(a.f64().unwrap(), b.f64().unwrap(), ddof)
pearson_corr(a.f64().unwrap(), b.f64().unwrap())
},
};
Ok(Column::new(name, &[ret]))
}

#[cfg(all(feature = "rank", feature = "propagate_nans"))]
fn spearman_rank_corr(s: &[Column], ddof: u8, propagate_nans: bool) -> PolarsResult<Column> {
fn spearman_rank_corr(s: &[Column], propagate_nans: bool) -> PolarsResult<Column> {
use polars_core::utils::coalesce_nulls_columns;
use polars_ops::chunked_array::nan_propagating_aggregate::nan_max_s;
let a = &s[0];
Expand Down Expand Up @@ -134,5 +132,5 @@ fn spearman_rank_corr(s: &[Column], ddof: u8, propagate_nans: bool) -> PolarsRes
)
.into();

pearson_corr(&[a_rank, b_rank], ddof)
pearson_corr(&[a_rank, b_rank])
}
14 changes: 4 additions & 10 deletions crates/polars-plan/src/dsl/functions/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,11 @@ pub fn cov(a: Expr, b: Expr, ddof: u8) -> Expr {
}

/// Compute the pearson correlation between two columns.
///
/// # Arguments
/// * ddof
/// Delta degrees of freedom
pub fn pearson_corr(a: Expr, b: Expr, ddof: u8) -> Expr {
pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
let input = vec![a, b];
let function = FunctionExpr::Correlation {
method: CorrelationMethod::Pearson,
ddof,
ddof: 0u8,
};
Expr::Function {
input,
Expand All @@ -45,18 +41,16 @@ pub fn pearson_corr(a: Expr, b: Expr, ddof: u8) -> Expr {
/// Compute the spearman rank correlation between two columns.
/// Missing data will be excluded from the computation.
/// # Arguments
/// * ddof
/// Delta degrees of freedom
/// * propagate_nans
/// If `true` any `NaN` encountered will lead to `NaN` in the output.
/// If to `false` then `NaN` are regarded as larger than any finite number
/// and thus lead to the highest rank.
#[cfg(all(feature = "rank", feature = "propagate_nans"))]
pub fn spearman_rank_corr(a: Expr, b: Expr, ddof: u8, propagate_nans: bool) -> Expr {
pub fn spearman_rank_corr(a: Expr, b: Expr, propagate_nans: bool) -> Expr {
let input = vec![a, b];
let function = FunctionExpr::Correlation {
method: CorrelationMethod::SpearmanRank(propagate_nans),
ddof,
ddof: 0u8,
};
Expr::Function {
input,
Expand Down
8 changes: 4 additions & 4 deletions crates/polars-python/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,8 @@ pub fn map_mul(
}

#[pyfunction]
pub fn pearson_corr(a: PyExpr, b: PyExpr, ddof: u8) -> PyExpr {
dsl::pearson_corr(a.inner, b.inner, ddof).into()
pub fn pearson_corr(a: PyExpr, b: PyExpr) -> PyExpr {
dsl::pearson_corr(a.inner, b.inner).into()
}

#[pyfunction]
Expand Down Expand Up @@ -537,10 +537,10 @@ pub fn repeat(value: PyExpr, n: PyExpr, dtype: Option<Wrap<DataType>>) -> PyResu
}

#[pyfunction]
pub fn spearman_rank_corr(a: PyExpr, b: PyExpr, ddof: u8, propagate_nans: bool) -> PyExpr {
pub fn spearman_rank_corr(a: PyExpr, b: PyExpr, propagate_nans: bool) -> PyExpr {
#[cfg(feature = "propagate_nans")]
{
dsl::spearman_rank_corr(a.inner, b.inner, ddof, propagate_nans).into()
dsl::spearman_rank_corr(a.inner, b.inner, propagate_nans).into()
}
#[cfg(not(feature = "propagate_nans"))]
{
Expand Down

0 comments on commit b880194

Please sign in to comment.