Skip to content

Commit

Permalink
fix: Defer credential provider resolution to take place at query coll…
Browse files Browse the repository at this point in the history
…ection instead of construction (#21225)
  • Loading branch information
nameexhaustion authored Feb 13, 2025
1 parent ed19c73 commit 7f22a25
Show file tree
Hide file tree
Showing 25 changed files with 838 additions and 417 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/serde/df.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ impl<'de> Deserialize<'de> for DataFrame {
where
D: Deserializer<'de>,
{
deserialize_map_bytes(deserializer, &mut |b| {
deserialize_map_bytes(deserializer, |b| {
let v = &mut b.as_ref();
Self::deserialize_from_reader(v)
})?
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/serde/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl<'de> Deserialize<'de> for Series {
where
D: Deserializer<'de>,
{
deserialize_map_bytes(deserializer, &mut |b| {
deserialize_map_bytes(deserializer, |b| {
let v = &mut b.as_ref();
Self::deserialize_from_reader(v)
})?
Expand Down
146 changes: 99 additions & 47 deletions crates/polars-io/src/cloud/credential_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub use object_store::gcp::GcpCredential;
use polars_core::config;
use polars_error::{polars_bail, PolarsResult};
#[cfg(feature = "python")]
use polars_utils::python_function::PythonFunction;
use polars_utils::python_function::PythonObject;
#[cfg(feature = "python")]
use python_impl::PythonCredentialProvider;

Expand Down Expand Up @@ -43,23 +43,32 @@ impl PlCredentialProvider {
Self::Function(CredentialProviderFunction(Arc::new(func)))
}

/// Intended to be called with an internal `CredentialProviderBuilder` from
/// py-polars.
#[cfg(feature = "python")]
pub fn from_python_func(func: PythonFunction) -> Self {
Self::Python(python_impl::PythonCredentialProvider(Arc::new(func)))
}

#[cfg(feature = "python")]
pub fn from_python_func_object(func: pyo3::PyObject) -> Self {
Self::Python(python_impl::PythonCredentialProvider(Arc::new(
PythonFunction(func),
pub fn from_python_builder(func: pyo3::PyObject) -> Self {
Self::Python(python_impl::PythonCredentialProvider::Builder(Arc::new(
PythonObject(func),
)))
}

pub(super) fn func_addr(&self) -> usize {
match self {
Self::Function(CredentialProviderFunction(v)) => Arc::as_ptr(v) as *const () as usize,
#[cfg(feature = "python")]
Self::Python(PythonCredentialProvider(v)) => Arc::as_ptr(v) as *const () as usize,
Self::Python(v) => v.func_addr(),
}
}

/// Python passes a `CredentialProviderBuilder`, this calls the builder to build the final
/// credential provider.
///
/// This returns `Option` as the auto-initialization case is fallible and falls back to None.
pub(crate) fn try_into_initialized(self) -> PolarsResult<Option<Self>> {
match self {
Self::Function(_) => Ok(Some(self)),
#[cfg(feature = "python")]
Self::Python(v) => Ok(v.try_into_initialized()?.map(Self::Python)),
}
}
}
Expand Down Expand Up @@ -452,8 +461,8 @@ mod python_impl {
use std::hash::Hash;
use std::sync::Arc;

use polars_error::PolarsError;
use polars_utils::python_function::PythonFunction;
use polars_error::{to_compute_err, PolarsError, PolarsResult};
use polars_utils::python_function::PythonObject;
use pyo3::exceptions::PyValueError;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods};
Expand All @@ -462,11 +471,71 @@ mod python_impl {
use super::IntoCredentialProvider;

#[derive(Clone, Debug)]
pub struct PythonCredentialProvider(pub(super) Arc<PythonFunction>);
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PythonCredentialProvider {
#[cfg_attr(
feature = "serde",
serde(
serialize_with = "PythonObject::serialize_with_pyversion",
deserialize_with = "PythonObject::deserialize_with_pyversion"
)
)]
/// Indicates `py_object` is a `CredentialProviderBuilder`.
Builder(Arc<PythonObject>),
#[cfg_attr(
feature = "serde",
serde(
serialize_with = "PythonObject::serialize_with_pyversion",
deserialize_with = "PythonObject::deserialize_with_pyversion"
)
)]
/// Indicates `py_object` is an instantiated credential provider
Provider(Arc<PythonObject>),
}

impl From<PythonFunction> for PythonCredentialProvider {
fn from(value: PythonFunction) -> Self {
Self(Arc::new(value))
impl PythonCredentialProvider {
/// Performs initialization if necessary.
///
/// This exists as a separate step that must be called beforehand. This approach is easier
/// as the alternative is to refactor the `IntoCredentialProvider` trait to return
/// `PolarsResult<Option<T>>` for every single function.
pub(super) fn try_into_initialized(self) -> PolarsResult<Option<Self>> {
match self {
Self::Builder(py_object) => {
let opt_initialized_py_object = Python::with_gil(|py| {
let build_fn = py_object.getattr(py, "build_credential_provider")?;

let v = build_fn.call0(py)?;
let v = (!v.is_none(py)).then_some(v);

pyo3::PyResult::Ok(v)
})
.map_err(to_compute_err)?;

Ok(opt_initialized_py_object
.map(PythonObject)
.map(Arc::new)
.map(Self::Provider))
},
Self::Provider(_) => {
// Note: We don't expect to hit here.
Ok(Some(self))
},
}
}

fn unwrap_as_provider(self) -> Arc<PythonObject> {
match self {
Self::Builder(_) => panic!(),
Self::Provider(v) => v,
}
}

pub(super) fn func_addr(&self) -> usize {
(match self {
Self::Builder(v) => Arc::as_ptr(v),
Self::Provider(v) => Arc::as_ptr(v),
}) as *const () as usize
}
}

Expand All @@ -479,8 +548,10 @@ mod python_impl {
CredentialProviderFunction, ObjectStoreCredential,
};

let func = self.unwrap_as_provider();

CredentialProviderFunction(Arc::new(move || {
let func = self.0.clone();
let func = func.clone();
Box::pin(async move {
let mut credentials = object_store::aws::AwsCredential {
key_id: String::new(),
Expand Down Expand Up @@ -554,8 +625,10 @@ mod python_impl {
CredentialProviderFunction, ObjectStoreCredential,
};

let func = self.unwrap_as_provider();

CredentialProviderFunction(Arc::new(move || {
let func = self.0.clone();
let func = func.clone();
Box::pin(async move {
let mut credentials = None;

Expand Down Expand Up @@ -621,8 +694,10 @@ mod python_impl {
CredentialProviderFunction, ObjectStoreCredential,
};

let func = self.unwrap_as_provider();

CredentialProviderFunction(Arc::new(move || {
let func = self.0.clone();
let func = func.clone();
Box::pin(async move {
let mut credentials = object_store::gcp::GcpCredential {
bearer: String::new(),
Expand Down Expand Up @@ -666,11 +741,14 @@ mod python_impl {
}
}

// Note: We don't consider `is_builder` for hash/eq - we don't expect the same Arc<PythonObject>
// to be referenced as both true and false from the `is_builder` field.

impl Eq for PythonCredentialProvider {}

impl PartialEq for PythonCredentialProvider {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.0, &other.0)
self.func_addr() == other.func_addr()
}
}

Expand All @@ -680,33 +758,7 @@ mod python_impl {
// * Inner is an `Arc`
// * Visibility is limited to super
// * No code in `mod python_impl` or `super` mutates the Arc inner.
state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
}
}

#[cfg(feature = "serde")]
mod _serde_impl {
use polars_utils::python_function::PySerializeWrap;

use super::PythonCredentialProvider;

impl serde::Serialize for PythonCredentialProvider {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
PySerializeWrap(self.0.as_ref()).serialize(serializer)
}
}

impl<'a> serde::Deserialize<'a> for PythonCredentialProvider {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'a>,
{
PySerializeWrap::<super::PythonFunction>::deserialize(deserializer)
.map(|x| x.0.into())
}
state.write_usize(self.func_addr())
}
}
}
Expand Down
35 changes: 19 additions & 16 deletions crates/polars-io/src/cloud/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ use regex::Regex;
#[cfg(feature = "http")]
use reqwest::header::HeaderMap;
#[cfg(feature = "serde")]
use serde::Deserializer;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "cloud")]
use url::Url;
Expand Down Expand Up @@ -80,19 +78,11 @@ pub struct CloudOptions {
pub file_cache_ttl: u64,
pub(crate) config: Option<CloudConfig>,
#[cfg(feature = "cloud")]
#[cfg_attr(feature = "serde", serde(deserialize_with = "deserialize_or_default"))]
/// Note: In most cases you will want to access this via [`CloudOptions::initialized_credential_provider`]
/// rather than directly.
pub(crate) credential_provider: Option<PlCredentialProvider>,
}

#[cfg(all(feature = "serde", feature = "cloud"))]
fn deserialize_or_default<'de, D>(deserializer: D) -> Result<Option<PlCredentialProvider>, D::Error>
where
D: Deserializer<'de>,
{
type T = Option<PlCredentialProvider>;
T::deserialize(deserializer).or_else(|_| Ok(Default::default()))
}

impl Default for CloudOptions {
fn default() -> Self {
Self::default_static_ref().clone()
Expand Down Expand Up @@ -392,7 +382,7 @@ impl CloudOptions {

let builder = builder.with_retry(get_retry_config(self.max_retries));

let builder = if let Some(v) = self.credential_provider.clone() {
let builder = if let Some(v) = self.initialized_credential_provider()? {
builder.with_credentials(v.into_aws_provider())
} else {
builder
Expand Down Expand Up @@ -438,7 +428,7 @@ impl CloudOptions {
.with_url(url)
.with_retry(get_retry_config(self.max_retries));

let builder = if let Some(v) = self.credential_provider.clone() {
let builder = if let Some(v) = self.initialized_credential_provider()? {
if verbose {
eprintln!(
"[CloudOptions::build_azure]: Using credential provider {:?}",
Expand Down Expand Up @@ -470,7 +460,9 @@ impl CloudOptions {
pub fn build_gcp(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
use super::credential_provider::IntoCredentialProvider;

let builder = if self.credential_provider.is_none() {
let credential_provider = self.initialized_credential_provider()?;

let builder = if credential_provider.is_none() {
GoogleCloudStorageBuilder::from_env()
} else {
GoogleCloudStorageBuilder::new()
Expand All @@ -491,7 +483,7 @@ impl CloudOptions {
.with_url(url)
.with_retry(get_retry_config(self.max_retries));

let builder = if let Some(v) = self.credential_provider.clone() {
let builder = if let Some(v) = credential_provider.clone() {
builder.with_credentials(v.into_gcp_provider())
} else {
builder
Expand Down Expand Up @@ -629,6 +621,17 @@ impl CloudOptions {
},
}
}

/// Python passes a credential provider builder that needs to be called to get the actual credential
/// provider.
#[cfg(feature = "cloud")]
fn initialized_credential_provider(&self) -> PolarsResult<Option<PlCredentialProvider>> {
if let Some(v) = self.credential_provider.clone() {
v.try_into_initialized()
} else {
Ok(None)
}
}
}

#[cfg(feature = "cloud")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ where
{
use polars_utils::pl_serialize::deserialize_map_bytes;

deserialize_map_bytes(deserializer, &mut |b| {
deserialize_map_bytes(deserializer, |b| {
let mut b = b.as_ref();
let mut protocol = TCompactInputProtocol::new(&mut b, usize::MAX);
ColumnChunk::read_from_in_protocol(&mut protocol).map_err(D::Error::custom)
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/expr_dyn_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn ColumnsUdf>> {
use serde::de::Error;
#[cfg(feature = "python")]
{
deserialize_map_bytes(deserializer, &mut |buf| {
deserialize_map_bytes(deserializer, |buf| {
if buf.starts_with(crate::dsl::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
let udf = crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(&buf)
.map_err(|e| D::Error::custom(format!("{e}")))?;
Expand Down Expand Up @@ -407,7 +407,7 @@ impl<'a> Deserialize<'a> for GetOutput {
use serde::de::Error;
#[cfg(feature = "python")]
{
deserialize_map_bytes(deserializer, &mut |buf| {
deserialize_map_bytes(deserializer, |buf| {
if buf.starts_with(self::python_dsl::PYTHON_SERDE_MAGIC_BYTE_MARK) {
let get_output = self::python_dsl::PythonGetOutput::try_deserialize(&buf)
.map_err(|e| D::Error::custom(format!("{e}")))?;
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-python/src/catalog/unity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ impl PyCatalogClient {
parse_cloud_options(storage_location, cloud_options.unwrap_or_default())?
.with_max_retries(retries)
.with_credential_provider(
credential_provider.map(PlCredentialProvider::from_python_func_object),
credential_provider.map(PlCredentialProvider::from_python_builder),
);

Ok(
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-python/src/dataframe/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ impl PyDataFrame {
cloud_options
.with_max_retries(retries)
.with_credential_provider(
credential_provider.map(PlCredentialProvider::from_python_func_object),
credential_provider.map(PlCredentialProvider::from_python_builder),
),
)
} else {
Expand Down Expand Up @@ -424,7 +424,7 @@ impl PyDataFrame {
cloud_options
.with_max_retries(retries)
.with_credential_provider(
credential_provider.map(PlCredentialProvider::from_python_func_object),
credential_provider.map(PlCredentialProvider::from_python_builder),
),
)
} else {
Expand Down Expand Up @@ -517,7 +517,7 @@ impl PyDataFrame {
cloud_options
.with_max_retries(retries)
.with_credential_provider(
credential_provider.map(PlCredentialProvider::from_python_func_object),
credential_provider.map(PlCredentialProvider::from_python_builder),
),
)
} else {
Expand Down
Loading

0 comments on commit 7f22a25

Please sign in to comment.