Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flatten error handling, reducing generics #54

Merged
merged 2 commits into from
Oct 6, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Simplify a lot of the error propagation
smklein committed Oct 4, 2023
commit 92b4770238b8421bfbb534f4c863f7e1b7ab83ea
1 change: 1 addition & 0 deletions examples/customize_connection.rs
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@ impl bb8::CustomizeConnection<DieselPgConn, ConnectionError> for ConnectionCusto
connection
.batch_execute_async("please execute some raw sql for me")
.await
.map_err(ConnectionError::from)
}
}

5 changes: 2 additions & 3 deletions examples/usage.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use async_bb8_diesel::{
AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError, OptionalExtension,
};
use async_bb8_diesel::{AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError};
use diesel::OptionalExtension;
use diesel::{pg::PgConnection, prelude::*};

table! {
163 changes: 104 additions & 59 deletions src/async_traits.rs
Original file line number Diff line number Diff line change
@@ -18,42 +18,37 @@ use tokio::task::spawn_blocking;

/// An async variant of [`diesel::connection::SimpleConnection`].
#[async_trait]
pub trait AsyncSimpleConnection<Conn, ConnErr>
pub trait AsyncSimpleConnection<Conn>
where
Conn: 'static + SimpleConnection,
{
async fn batch_execute_async(&self, query: &str) -> Result<(), ConnErr>;
async fn batch_execute_async(&self, query: &str) -> Result<(), DieselError>;
}

/// An async variant of [`diesel::connection::Connection`].
#[async_trait]
pub trait AsyncConnection<Conn, ConnErr>: AsyncSimpleConnection<Conn, ConnErr>
pub trait AsyncConnection<Conn>: AsyncSimpleConnection<Conn>
where
Conn: 'static + DieselConnection,
ConnErr: From<DieselError> + Send + 'static,
Self: Send,
{
type OwnedConnection: Sync + Send + 'static;

#[doc(hidden)]
async fn get_owned_connection(&self) -> Result<Self::OwnedConnection, ConnErr>;
async fn get_owned_connection(&self) -> Self::OwnedConnection;
#[doc(hidden)]
fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn>;
#[doc(hidden)]
fn as_async_conn(owned: &Self::OwnedConnection) -> &SingleConnection<Conn>;

/// Runs the function `f` in an context where blocking is safe.
///
/// Any error may be propagated through `f`, as long as that
/// error type may be constructed from `ConnErr` (as that error
/// type may also be generated).
async fn run<R, E, Func>(&self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: From<ConnErr> + Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
let connection = self.get_owned_connection().await?;
let connection = self.get_owned_connection().await;
Self::run_with_connection(connection, f).await
}

@@ -64,7 +59,7 @@ where
) -> Result<R, E>
where
R: Send + 'static,
E: From<ConnErr> + Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
@@ -79,7 +74,7 @@ where
) -> Result<R, E>
where
R: Send + 'static,
E: From<ConnErr> + Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
@@ -90,57 +85,58 @@ where
async fn transaction<R, E, Func>(&self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: From<DieselError> + From<ConnErr> + Send + 'static,
E: From<DieselError> + Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
self.run(|conn| conn.transaction(|c| f(c))).await
}

// TODO(for both the transaction functions below): The ideal interface would
// pass the "async_conn" object to the underlying function "f" by reference.
//
// This would prevent the user-supplied closure + future from using the
// connection *beyond* the duration of the transaction, which would be
// bad.
//
// However, I'm struggling to get these lifetimes to work properly. If
// you can figure out a way to convince that the reference lives long
// enough to be referenceable by a Future, but short enough that we can
// guarantee it doesn't live persist after this function returns, feel
// free to make that change.

async fn transaction_async<R, E, Func, Fut, 'a>(&'a self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: From<DieselError> + From<ConnErr> + Send,
E: From<DieselError> + Send + 'static,
Fut: Future<Output = Result<R, E>> + Send,
Func: FnOnce(SingleConnection<Conn>) -> Fut + Send,
{
// Check out a connection once, and use it for the duration of the
// operation.
let conn = Arc::new(self.get_owned_connection().await?);
let conn = Arc::new(self.get_owned_connection().await);

// This function mimics the implementation of:
// https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction
//
// However, it modifies all callsites to instead issue
// known-to-be-synchronous operations from an asynchronous context.
Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::begin_transaction(conn).map_err(ConnErr::from)
Conn::TransactionManager::begin_transaction(conn).map_err(E::from)
})
.await?;

// TODO: The ideal interface would pass the "async_conn" object to the
// underlying function "f" by reference.
//
// This would prevent the user-supplied closure + future from using the
// connection *beyond* the duration of the transaction, which would be
// bad.
//
// However, I'm struggling to get these lifetimes to work properly. If
// you can figure out a way to convince that the reference lives long
// enough to be referenceable by a Future, but short enough that we can
// guarantee it doesn't live persist after this function returns, feel
// free to make that change.
let async_conn = SingleConnection(Self::as_async_conn(&conn).0.clone());
match f(async_conn).await {
Ok(value) => {
Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::commit_transaction(conn).map_err(ConnErr::from)
Conn::TransactionManager::commit_transaction(conn).map_err(E::from)
})
.await?;
Ok(value)
}
Err(user_error) => {
match Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::rollback_transaction(conn).map_err(ConnErr::from)
Conn::TransactionManager::rollback_transaction(conn).map_err(E::from)
})
.await
{
@@ -150,116 +146,165 @@ where
}
}
}

async fn transaction_async_with_retry<R, E, Func, Fut, RetryFunc, RetryFut, 'a>(
&'a self,
f: Func,
should_retry: RetryFunc,
) -> Result<R, E>
where
R: Send + 'static,
E: From<DieselError> + Send + 'static,
Fut: Future<Output = Result<R, E>> + Send,
Func: Fn(SingleConnection<Conn>) -> Fut + Send + Sync,
RetryFut: Future<Output = bool> + Send,
RetryFunc: Fn() -> RetryFut + Send + Sync,
{
// Check out a connection once, and use it for the duration of the
// operation.
let conn = Arc::new(self.get_owned_connection().await);

// This function mimics the implementation of:
// https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction
//
// However, it modifies all callsites to instead issue
// known-to-be-synchronous operations from an asynchronous context.
Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::begin_transaction(conn).map_err(E::from)
})
.await?;

loop {
let async_conn = SingleConnection(Self::as_async_conn(&conn).0.clone());
match f(async_conn).await {
Ok(value) => {
Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::commit_transaction(conn).map_err(E::from)
})
.await?;
return Ok(value);
}
Err(user_error) => {
match Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::rollback_transaction(conn).map_err(E::from)
})
.await
{
Ok(()) => return Err(user_error),
Err(err) => {
return Err(err.into());
}
}
}
}
}
}
}

/// An async variant of [`diesel::query_dsl::RunQueryDsl`].
#[async_trait]
pub trait AsyncRunQueryDsl<Conn, AsyncConn, E>
pub trait AsyncRunQueryDsl<Conn, AsyncConn>
where
Conn: 'static + DieselConnection,
{
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, E>
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, DieselError>
where
Self: ExecuteDsl<Conn>;

async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LimitDsl,
Limit<Self>: LoadQuery<'static, Conn, U>;
}

#[async_trait]
impl<T, AsyncConn, Conn, E> AsyncRunQueryDsl<Conn, AsyncConn, E> for T
impl<T, AsyncConn, Conn> AsyncRunQueryDsl<Conn, AsyncConn> for T
where
T: 'static + Send + RunQueryDsl<Conn>,
Conn: 'static + DieselConnection,
AsyncConn: Send + Sync + AsyncConnection<Conn, E>,
E: From<DieselError> + Send + 'static,
AsyncConn: Send + Sync + AsyncConnection<Conn>,
{
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, E>
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, DieselError>
where
Self: ExecuteDsl<Conn>,
{
asc.run(|conn| self.execute(conn).map_err(E::from)).await
asc.run(|conn| self.execute(conn)).await
}

async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.load(conn).map_err(E::from)).await
asc.run(|conn| self.load(conn)).await
}

async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.get_result(conn).map_err(E::from)).await
asc.run(|conn| self.get_result(conn)).await
}

async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.get_results(conn).map_err(E::from))
.await
asc.run(|conn| self.get_results(conn)).await
}

async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LimitDsl,
Limit<Self>: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.first(conn).map_err(E::from)).await
asc.run(|conn| self.first(conn)).await
}
}

#[async_trait]
pub trait AsyncSaveChangesDsl<Conn, AsyncConn, E>
pub trait AsyncSaveChangesDsl<Conn, AsyncConn>
where
Conn: 'static + DieselConnection,
{
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, E>
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, DieselError>
where
Self: Sized,
Conn: diesel::query_dsl::UpdateAndFetchResults<Self, Output>,
Output: Send + 'static;
}

#[async_trait]
impl<T, AsyncConn, Conn, E> AsyncSaveChangesDsl<Conn, AsyncConn, E> for T
impl<T, AsyncConn, Conn> AsyncSaveChangesDsl<Conn, AsyncConn> for T
where
T: 'static + Send + Sync + diesel::SaveChangesDsl<Conn>,
Conn: 'static + DieselConnection,
AsyncConn: Send + Sync + AsyncConnection<Conn, E>,
E: 'static + Send + From<DieselError>,
AsyncConn: Send + Sync + AsyncConnection<Conn>,
{
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, E>
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, DieselError>
where
Conn: diesel::query_dsl::UpdateAndFetchResults<Self, Output>,
Output: Send + 'static,
{
asc.run(|conn| self.save_changes(conn).map_err(E::from))
.await
asc.run(|conn| self.save_changes(conn)).await
}
}
14 changes: 6 additions & 8 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! An async wrapper around a [`diesel::Connection`].
use crate::{ConnectionError, ConnectionResult};
use async_trait::async_trait;
use diesel::r2d2::R2D2Connection;
use std::sync::{Arc, Mutex, MutexGuard};
@@ -31,31 +30,30 @@ impl<C> Connection<C> {
}

#[async_trait]
impl<Conn> crate::AsyncSimpleConnection<Conn, ConnectionError> for Connection<Conn>
impl<Conn> crate::AsyncSimpleConnection<Conn> for Connection<Conn>
where
Conn: 'static + R2D2Connection,
{
#[inline]
async fn batch_execute_async(&self, query: &str) -> ConnectionResult<()> {
async fn batch_execute_async(&self, query: &str) -> Result<(), diesel::result::Error> {
let diesel_conn = Connection(self.0.clone());
let query = query.to_string();
task::spawn_blocking(move || diesel_conn.inner().batch_execute(&query))
.await
.unwrap() // Propagate panics
.map_err(ConnectionError::from)
}
}

#[async_trait]
impl<Conn> crate::AsyncConnection<Conn, ConnectionError> for Connection<Conn>
impl<Conn> crate::AsyncConnection<Conn> for Connection<Conn>
where
Conn: 'static + R2D2Connection,
Connection<Conn>: crate::AsyncSimpleConnection<Conn, ConnectionError>,
Connection<Conn>: crate::AsyncSimpleConnection<Conn>,
{
type OwnedConnection = Connection<Conn>;

async fn get_owned_connection(&self) -> Result<Self::OwnedConnection, ConnectionError> {
Ok(Connection(self.0.clone()))
async fn get_owned_connection(&self) -> Self::OwnedConnection {
Connection(self.0.clone())
}

fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn> {