Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flatten error handling, reducing generics #54

Merged
merged 2 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/customize_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ impl bb8::CustomizeConnection<DieselPgConn, ConnectionError> for ConnectionCusto
connection
.batch_execute_async("please execute some raw sql for me")
.await
.map_err(ConnectionError::from)
}
}

Expand Down
5 changes: 2 additions & 3 deletions examples/usage.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use async_bb8_diesel::{
AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError, OptionalExtension,
};
use async_bb8_diesel::{AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError};
use diesel::OptionalExtension;
use diesel::{pg::PgConnection, prelude::*};

table! {
Expand Down
85 changes: 38 additions & 47 deletions src/async_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,42 +18,37 @@ use tokio::task::spawn_blocking;

/// An async variant of [`diesel::connection::SimpleConnection`].
#[async_trait]
pub trait AsyncSimpleConnection<Conn, ConnErr>
pub trait AsyncSimpleConnection<Conn>
where
Conn: 'static + SimpleConnection,
{
async fn batch_execute_async(&self, query: &str) -> Result<(), ConnErr>;
async fn batch_execute_async(&self, query: &str) -> Result<(), DieselError>;
}

/// An async variant of [`diesel::connection::Connection`].
#[async_trait]
pub trait AsyncConnection<Conn, ConnErr>: AsyncSimpleConnection<Conn, ConnErr>
pub trait AsyncConnection<Conn>: AsyncSimpleConnection<Conn>
where
Conn: 'static + DieselConnection,
ConnErr: From<DieselError> + Send + 'static,
Self: Send,
{
type OwnedConnection: Sync + Send + 'static;

#[doc(hidden)]
async fn get_owned_connection(&self) -> Result<Self::OwnedConnection, ConnErr>;
async fn get_owned_connection(&self) -> Self::OwnedConnection;
#[doc(hidden)]
fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn>;
#[doc(hidden)]
fn as_async_conn(owned: &Self::OwnedConnection) -> &SingleConnection<Conn>;

/// Runs the function `f` in an context where blocking is safe.
///
/// Any error may be propagated through `f`, as long as that
/// error type may be constructed from `ConnErr` (as that error
/// type may also be generated).
async fn run<R, E, Func>(&self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: From<ConnErr> + Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
let connection = self.get_owned_connection().await?;
let connection = self.get_owned_connection().await;
Self::run_with_connection(connection, f).await
}

Expand All @@ -64,7 +59,7 @@ where
) -> Result<R, E>
where
R: Send + 'static,
E: From<ConnErr> + Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
Expand All @@ -79,7 +74,7 @@ where
) -> Result<R, E>
where
R: Send + 'static,
E: From<ConnErr> + Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
Expand All @@ -90,7 +85,7 @@ where
async fn transaction<R, E, Func>(&self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: From<DieselError> + From<ConnErr> + Send + 'static,
E: From<DieselError> + Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
self.run(|conn| conn.transaction(|c| f(c))).await
Expand All @@ -99,21 +94,21 @@ where
async fn transaction_async<R, E, Func, Fut, 'a>(&'a self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: From<DieselError> + From<ConnErr> + Send,
E: From<DieselError> + Send + 'static,
Fut: Future<Output = Result<R, E>> + Send,
Func: FnOnce(SingleConnection<Conn>) -> Fut + Send,
{
// Check out a connection once, and use it for the duration of the
// operation.
let conn = Arc::new(self.get_owned_connection().await?);
let conn = Arc::new(self.get_owned_connection().await);

// This function mimics the implementation of:
// https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction
//
// However, it modifies all callsites to instead issue
// known-to-be-synchronous operations from an asynchronous context.
Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::begin_transaction(conn).map_err(ConnErr::from)
Conn::TransactionManager::begin_transaction(conn).map_err(E::from)
})
.await?;

Expand All @@ -133,14 +128,14 @@ where
match f(async_conn).await {
Ok(value) => {
Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::commit_transaction(conn).map_err(ConnErr::from)
Conn::TransactionManager::commit_transaction(conn).map_err(E::from)
})
.await?;
Ok(value)
}
Err(user_error) => {
match Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::rollback_transaction(conn).map_err(ConnErr::from)
Conn::TransactionManager::rollback_transaction(conn).map_err(E::from)
})
.await
{
Expand All @@ -154,112 +149,108 @@ where

/// An async variant of [`diesel::query_dsl::RunQueryDsl`].
#[async_trait]
pub trait AsyncRunQueryDsl<Conn, AsyncConn, E>
pub trait AsyncRunQueryDsl<Conn, AsyncConn>
where
Conn: 'static + DieselConnection,
{
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, E>
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, DieselError>
where
Self: ExecuteDsl<Conn>;

async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LimitDsl,
Limit<Self>: LoadQuery<'static, Conn, U>;
}

#[async_trait]
impl<T, AsyncConn, Conn, E> AsyncRunQueryDsl<Conn, AsyncConn, E> for T
impl<T, AsyncConn, Conn> AsyncRunQueryDsl<Conn, AsyncConn> for T
where
T: 'static + Send + RunQueryDsl<Conn>,
Conn: 'static + DieselConnection,
AsyncConn: Send + Sync + AsyncConnection<Conn, E>,
E: From<DieselError> + Send + 'static,
AsyncConn: Send + Sync + AsyncConnection<Conn>,
{
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, E>
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, DieselError>
where
Self: ExecuteDsl<Conn>,
{
asc.run(|conn| self.execute(conn).map_err(E::from)).await
asc.run(|conn| self.execute(conn)).await
}

async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.load(conn).map_err(E::from)).await
asc.run(|conn| self.load(conn)).await
}

async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.get_result(conn).map_err(E::from)).await
asc.run(|conn| self.get_result(conn)).await
}

async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.get_results(conn).map_err(E::from))
.await
asc.run(|conn| self.get_results(conn)).await
}

async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, E>
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
where
U: Send + 'static,
Self: LimitDsl,
Limit<Self>: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.first(conn).map_err(E::from)).await
asc.run(|conn| self.first(conn)).await
}
}

#[async_trait]
pub trait AsyncSaveChangesDsl<Conn, AsyncConn, E>
pub trait AsyncSaveChangesDsl<Conn, AsyncConn>
where
Conn: 'static + DieselConnection,
{
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, E>
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, DieselError>
where
Self: Sized,
Conn: diesel::query_dsl::UpdateAndFetchResults<Self, Output>,
Output: Send + 'static;
}

#[async_trait]
impl<T, AsyncConn, Conn, E> AsyncSaveChangesDsl<Conn, AsyncConn, E> for T
impl<T, AsyncConn, Conn> AsyncSaveChangesDsl<Conn, AsyncConn> for T
where
T: 'static + Send + Sync + diesel::SaveChangesDsl<Conn>,
Conn: 'static + DieselConnection,
AsyncConn: Send + Sync + AsyncConnection<Conn, E>,
E: 'static + Send + From<DieselError>,
AsyncConn: Send + Sync + AsyncConnection<Conn>,
{
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, E>
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, DieselError>
where
Conn: diesel::query_dsl::UpdateAndFetchResults<Self, Output>,
Output: Send + 'static,
{
asc.run(|conn| self.save_changes(conn).map_err(E::from))
.await
asc.run(|conn| self.save_changes(conn)).await
}
}
14 changes: 6 additions & 8 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! An async wrapper around a [`diesel::Connection`].

use crate::{ConnectionError, ConnectionResult};
use async_trait::async_trait;
use diesel::r2d2::R2D2Connection;
use std::sync::{Arc, Mutex, MutexGuard};
Expand Down Expand Up @@ -31,31 +30,30 @@ impl<C> Connection<C> {
}

#[async_trait]
impl<Conn> crate::AsyncSimpleConnection<Conn, ConnectionError> for Connection<Conn>
impl<Conn> crate::AsyncSimpleConnection<Conn> for Connection<Conn>
where
Conn: 'static + R2D2Connection,
{
#[inline]
async fn batch_execute_async(&self, query: &str) -> ConnectionResult<()> {
async fn batch_execute_async(&self, query: &str) -> Result<(), diesel::result::Error> {
let diesel_conn = Connection(self.0.clone());
let query = query.to_string();
task::spawn_blocking(move || diesel_conn.inner().batch_execute(&query))
.await
.unwrap() // Propagate panics
.map_err(ConnectionError::from)
}
}

#[async_trait]
impl<Conn> crate::AsyncConnection<Conn, ConnectionError> for Connection<Conn>
impl<Conn> crate::AsyncConnection<Conn> for Connection<Conn>
where
Conn: 'static + R2D2Connection,
Connection<Conn>: crate::AsyncSimpleConnection<Conn, ConnectionError>,
Connection<Conn>: crate::AsyncSimpleConnection<Conn>,
{
type OwnedConnection = Connection<Conn>;

async fn get_owned_connection(&self) -> Result<Self::OwnedConnection, ConnectionError> {
Ok(Connection(self.0.clone()))
async fn get_owned_connection(&self) -> Self::OwnedConnection {
Connection(self.0.clone())
}

fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn> {
Expand Down