From 174aadf9dd1b86917007b0a7d2ec3bdf74edfc42 Mon Sep 17 00:00:00 2001 From: Nathaniel Cook Date: Wed, 29 Jan 2025 11:49:15 -0700 Subject: [PATCH] fix: use invoke_batch on UDFs as invoke is deprecated DataFusion recently deprecated the invoke method on ScalarUDFs. This change updates the UDFs to use invoke_batch instead, following best practices going forward. --- pipeline/src/cid_string.rs | 149 +++++++++++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 7 deletions(-) diff --git a/pipeline/src/cid_string.rs b/pipeline/src/cid_string.rs index 6184d8d6..e1207818 100644 --- a/pipeline/src/cid_string.rs +++ b/pipeline/src/cid_string.rs @@ -5,7 +5,7 @@ use std::{any::Any, sync::Arc}; use cid::Cid; use datafusion::{ arrow::{ - array::{ArrayIter, ListBuilder, StringBuilder}, + array::{ListBuilder, StringBuilder}, datatypes::DataType, }, common::{ @@ -15,7 +15,10 @@ use datafusion::{ logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility}, }; -/// ScalarUDF to convert a binary CID into a string for easier inspection. #[derive(Debug)] +// Length of a typical Ceramic CID as a UTF8 string in bytes. +const CID_STRING_BYTES: usize = 60; + +/// ScalarUDF to convert a binary CID into a string for easier inspection. 
#[derive(Debug)] pub struct CidString { signature: Signature, @@ -52,10 +55,14 @@ impl ScalarUDFImpl for CidString { fn return_type(&self, _args: &[DataType]) -> datafusion::common::Result { Ok(DataType::Utf8) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion::common::Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> datafusion::error::Result { let args = ColumnarValue::values_to_arrays(args)?; let cids = as_binary_array(&args[0])?; - let mut strs = StringBuilder::new(); + let mut strs = StringBuilder::with_capacity(number_rows, CID_STRING_BYTES * number_rows); for cid in cids { if let Some(cid) = cid { strs.append_value( @@ -108,11 +115,23 @@ impl ScalarUDFImpl for CidStringList { fn return_type(&self, _args: &[DataType]) -> datafusion::common::Result { Ok(DataType::new_list(DataType::Utf8, true)) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion::common::Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> datafusion::error::Result { let args = ColumnarValue::values_to_arrays(args)?; let all_cids = as_list_array(&args[0])?; - let mut strs = ListBuilder::new(StringBuilder::new()); - for cids in ArrayIter::new(all_cids) { + // Count the number of cids before allocating. 
+ let cid_count = all_cids + .iter() + .map(|list| list.map(|l| l.len()).unwrap_or(0)) + .sum(); + let mut strs = ListBuilder::with_capacity( + StringBuilder::with_capacity(cid_count, CID_STRING_BYTES * cid_count), + number_rows, + ); + for cids in all_cids.iter() { if let Some(cids) = cids { let cids = as_binary_array(&cids)?; for cid in cids { @@ -134,3 +153,119 @@ impl ScalarUDFImpl for CidStringList { Ok(ColumnarValue::Array(Arc::new(strs.finish()))) } } + +#[cfg(test)] +mod tests { + use super::{CidString, CidStringList}; + + use std::{str::FromStr as _, sync::Arc}; + + use arrow::{ + array::{ArrayRef, ListBuilder}, + util::pretty::pretty_format_batches, + }; + use cid::Cid; + use datafusion::{ + arrow::array::{BinaryBuilder, StructArray}, + logical_expr::{expr::ScalarFunction, ScalarUDF}, + prelude::{col, Expr, SessionContext}, + }; + use expect_test::expect; + use test_log::test; + + #[test(tokio::test)] + async fn cid_string() -> anyhow::Result<()> { + let mut cids = BinaryBuilder::new(); + cids.append_value( + Cid::from_str("baeabeihqujx543mwoqik7lltjjtobjmdewsygp5s34mydy3web3dxhoide")? + .to_bytes(), + ); + cids.append_value( + Cid::from_str("baeabeih6c7vkvijmvu22oqwwirsukfci46evfp6rrydq4bavmoxw5yrzaq")? + .to_bytes(), + ); + cids.append_value( + Cid::from_str("baeabeiayakacgjkantxde2puuxqfrrq72cih4gykxov7gpdiypoz4d3oya")? + .to_bytes(), + ); + cids.append_value( + Cid::from_str("baeabeiaclaqgmmybteclwcb6p6hth24pwezy3chaghryslttst34d2lrny")? + .to_bytes(), + ); + cids.append_value( + Cid::from_str("baeabeibou6io3hgsapzk5kzdj6gxw5whbsa4oruhndgfonei35bugpqswm")? + .to_bytes(), + ); + + let batch = StructArray::try_from(vec![("cid", Arc::new(cids.finish()) as ArrayRef)])?; + let cid_string = Arc::new(ScalarUDF::from(CidString::new())); + let ctx = SessionContext::new(); + let output = ctx + .read_batch(batch.into())? + .select(vec![Expr::ScalarFunction(ScalarFunction::new_udf( + cid_string, + vec![col("cid")], + ))])? 
+ .collect() + .await?; + let output = pretty_format_batches(&output)?; + expect![[r#" + +-------------------------------------------------------------+ + | cid_string(?table?.cid) | + +-------------------------------------------------------------+ + | baeabeihqujx543mwoqik7lltjjtobjmdewsygp5s34mydy3web3dxhoide | + | baeabeih6c7vkvijmvu22oqwwirsukfci46evfp6rrydq4bavmoxw5yrzaq | + | baeabeiayakacgjkantxde2puuxqfrrq72cih4gykxov7gpdiypoz4d3oya | + | baeabeiaclaqgmmybteclwcb6p6hth24pwezy3chaghryslttst34d2lrny | + | baeabeibou6io3hgsapzk5kzdj6gxw5whbsa4oruhndgfonei35bugpqswm | + +-------------------------------------------------------------+"#]] + .assert_eq(&output.to_string()); + Ok(()) + } + #[test(tokio::test)] + async fn cid_string_list() -> anyhow::Result<()> { + let mut cids = ListBuilder::new(BinaryBuilder::new()); + cids.values().append_value( + Cid::from_str("baeabeihqujx543mwoqik7lltjjtobjmdewsygp5s34mydy3web3dxhoide")? + .to_bytes(), + ); + cids.values().append_value( + Cid::from_str("baeabeih6c7vkvijmvu22oqwwirsukfci46evfp6rrydq4bavmoxw5yrzaq")? + .to_bytes(), + ); + cids.values().append_value( + Cid::from_str("baeabeiayakacgjkantxde2puuxqfrrq72cih4gykxov7gpdiypoz4d3oya")? + .to_bytes(), + ); + cids.append(true); + cids.values().append_value( + Cid::from_str("baeabeiaclaqgmmybteclwcb6p6hth24pwezy3chaghryslttst34d2lrny")? + .to_bytes(), + ); + cids.values().append_value( + Cid::from_str("baeabeibou6io3hgsapzk5kzdj6gxw5whbsa4oruhndgfonei35bugpqswm")? + .to_bytes(), + ); + cids.append(true); + let batch = StructArray::try_from(vec![("cids", Arc::new(cids.finish()) as ArrayRef)])?; + let cid_string_list = Arc::new(ScalarUDF::from(CidStringList::new())); + let ctx = SessionContext::new(); + let output = ctx + .read_batch(batch.into())? + .select(vec![Expr::ScalarFunction(ScalarFunction::new_udf( + cid_string_list, + vec![col("cids")], + ))])? 
+ .collect() + .await?; + let output = pretty_format_batches(&output)?; + expect![[r#" + +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | array_cid_string(?table?.cids) | + +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + | [baeabeihqujx543mwoqik7lltjjtobjmdewsygp5s34mydy3web3dxhoide, baeabeih6c7vkvijmvu22oqwwirsukfci46evfp6rrydq4bavmoxw5yrzaq, baeabeiayakacgjkantxde2puuxqfrrq72cih4gykxov7gpdiypoz4d3oya] | + | [baeabeiaclaqgmmybteclwcb6p6hth24pwezy3chaghryslttst34d2lrny, baeabeibou6io3hgsapzk5kzdj6gxw5whbsa4oruhndgfonei35bugpqswm] | + +-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+"#]].assert_eq(&output.to_string()); + Ok(()) + } +}