From d9147e26a7bc55a266ff6898be6b6229540679ff Mon Sep 17 00:00:00 2001 From: Rain Date: Fri, 11 Oct 2024 18:46:47 -0700 Subject: [PATCH 1/2] [spr] initial version Created using spr 1.3.6-beta.1 --- src/async_traits.rs | 168 +++++++++++++++++++++++++++++--------------- src/error.rs | 46 ++++++++++++ src/lib.rs | 4 +- tests/test.rs | 28 +++++--- 4 files changed, 179 insertions(+), 67 deletions(-) diff --git a/src/async_traits.rs b/src/async_traits.rs index 38c879a..49e2c26 100644 --- a/src/async_traits.rs +++ b/src/async_traits.rs @@ -1,6 +1,6 @@ //! Async versions of traits for issuing Diesel queries. -use crate::connection::Connection; +use crate::{connection::Connection, error::RunError}; use async_trait::async_trait; use diesel::{ connection::{ @@ -21,7 +21,7 @@ use std::any::Any; use std::future::Future; use std::sync::Arc; use std::sync::MutexGuard; -use tokio::task::spawn_blocking; +use tokio::task::{spawn_blocking, JoinError}; /// An async variant of [`diesel::connection::SimpleConnection`]. #[async_trait] @@ -48,7 +48,7 @@ where Conn: 'static + DieselConnection + R2D2Connection, Self: Send + Sized + 'static, { - async fn ping_async(&mut self) -> diesel::result::QueryResult<()> { + async fn ping_async(&mut self) -> Result<(), RunError> { self.as_async_conn().run(|conn| conn.ping()).await } @@ -75,7 +75,7 @@ where fn as_async_conn(&self) -> &Connection; /// Runs the function `f` in an context where blocking is safe. - async fn run(&self, f: Func) -> Result + async fn run(&self, f: Func) -> Result> where R: Send + 'static, E: Send + 'static, @@ -86,32 +86,31 @@ where } #[doc(hidden)] - async fn run_with_connection(self, f: Func) -> Result + async fn run_with_connection(self, f: Func) -> Result> where R: Send + 'static, E: Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { - spawn_blocking(move || f(&mut *self.as_sync_conn())) - .await - .unwrap() // Propagate panics + handle_spawn_blocking_error(spawn_blocking(move || f(&mut *self.as_sync_conn())).await) } #[doc(hidden)] - async fn run_with_shared_connection(self: &Arc, f: Func) -> Result + async fn run_with_shared_connection( + self: &Arc, + f: Func, + ) -> Result> where R: Send + 'static, E: Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { let conn = self.clone(); - spawn_blocking(move || f(&mut *conn.as_sync_conn())) - .await - .unwrap() // Propagate panics + handle_spawn_blocking_error(spawn_blocking(move || f(&mut *conn.as_sync_conn())).await) } #[doc(hidden)] - async fn transaction_depth(&self) -> Result { + async fn transaction_depth(&self) -> Result> { let conn = self.get_owned_connection(); Self::run_with_connection(conn, |conn| { @@ -131,9 +130,9 @@ where // This method is a wrapper around that call, with validation that // we're actually issuing the BEGIN statement here. #[doc(hidden)] - async fn start_transaction(self: &Arc) -> Result<(), DieselError> { + async fn start_transaction(self: &Arc) -> Result<(), RunError> { if self.transaction_depth().await? != 0 { - return Err(DieselError::AlreadyInTransaction); + return Err(RunError::User(DieselError::AlreadyInTransaction)); } self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) .await?; @@ -146,11 +145,11 @@ where // This method is a wrapper around that call, with validation that // we're actually issuing our first SAVEPOINT here. #[doc(hidden)] - async fn add_retry_savepoint(self: &Arc) -> Result<(), DieselError> { + async fn add_retry_savepoint(self: &Arc) -> Result<(), RunError> { match self.transaction_depth().await? { - 0 => return Err(DieselError::NotInTransaction), + 0 => return Err(RunError::User(DieselError::NotInTransaction)), 1 => (), - _ => return Err(DieselError::AlreadyInTransaction), + _ => return Err(RunError::User(DieselError::AlreadyInTransaction)), }; self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) @@ -159,14 +158,14 @@ where } #[doc(hidden)] - async fn commit_transaction(self: &Arc) -> Result<(), DieselError> { + async fn commit_transaction(self: &Arc) -> Result<(), RunError> { self.run_with_shared_connection(|conn| Conn::TransactionManager::commit_transaction(conn)) .await?; Ok(()) } #[doc(hidden)] - async fn rollback_transaction(self: &Arc) -> Result<(), DieselError> { + async fn rollback_transaction(self: &Arc) -> Result<(), RunError> { self.run_with_shared_connection(|conn| { Conn::TransactionManager::rollback_transaction(conn) }) @@ -185,10 +184,10 @@ where &'a self, f: Func, retry: RetryFunc, - ) -> Result + ) -> Result> where R: Any + Send + 'static, - Fut: FutureExt> + Send, + Fut: FutureExt>> + Send, Func: (Fn(Connection) -> Fut) + Send + Sync, RetryFut: FutureExt + Send, RetryFunc: Fn() -> RetryFut + Send + Sync, @@ -221,11 +220,13 @@ where #[cfg(feature = "cockroach")] async fn transaction_async_with_retry_inner( &self, - f: &(dyn Fn(Connection) -> BoxFuture<'_, Result, DieselError>> + f: &(dyn Fn( + Connection, + ) -> BoxFuture<'_, Result, RunError>> + Send + Sync), retry: &(dyn Fn() -> BoxFuture<'_, bool> + Send + Sync), - ) -> Result, DieselError> { + ) -> Result, RunError> { // Check out a connection once, and use it for the duration of the // operation. let conn = Arc::new(self.get_owned_connection()); @@ -262,20 +263,26 @@ where // // We're still in the transaction, but we at least // tried to ROLLBACK to our savepoint. - if !retryable_error(&err) || !retry().await { + let retried = match &err { + RunError::User(err) => retryable_error(err) && retry().await, + RunError::RuntimeShutdown => false, + }; + if retried { + // ROLLBACK happened, we want to retry. + continue; + } else { // Bail: ROLLBACK the initial BEGIN statement too. + // In the case of the run let _ = Self::rollback_transaction(&conn).await; return Err(err); } - // ROLLBACK happened, we want to retry. - continue; } // Commit the top-level transaction too. Self::commit_transaction(&conn).await?; return Ok(value); } - Err(user_error) => { + Err(RunError::User(user_error)) => { // The user-level operation failed: ROLLBACK to the retry // savepoint. if let Err(first_rollback_err) = Self::rollback_transaction(&conn).await { @@ -295,10 +302,18 @@ where // If we aren't retrying, ROLLBACK the BEGIN statement too. return match Self::rollback_transaction(&conn).await { - Ok(()) => Err(user_error), + Ok(()) => Err(RunError::User(user_error)), Err(err) => Err(err), }; } + Err(RunError::RuntimeShutdown) => { + // The runtime is shutting down: attempt to ROLLBACK the + // transaction. You might think it's pointless to try and + // do this, but in reality the shutdown timeout for async + // tasks might not have expired. + let _ = Self::rollback_transaction(&conn).await; + return Err(RunError::RuntimeShutdown); + } } } } @@ -306,7 +321,7 @@ where async fn transaction_async(&'a self, f: Func) -> Result where R: Send + 'static, - E: From + Send + 'static, + E: From> + Send + 'static, Fut: Future> + Send, Func: FnOnce(Connection) -> Fut + Send, { @@ -325,9 +340,8 @@ where .boxed() }); - self.transaction_async_inner(f) - .await - .map(|v| *v.downcast::().expect("Should be an 'R' type")) + let v = self.transaction_async_inner(f).await?; + Ok(*v.downcast::().expect("Should be an 'R' type")) } // NOTE: This function intentionally avoids as many generic parameters as possible @@ -340,7 +354,7 @@ where >, ) -> Result, E> where - E: From + Send + 'static, + E: From> + Send + 'static, { // Check out a connection once, and use it for the duration of the // operation. @@ -352,9 +366,14 @@ where // However, it modifies all callsites to instead issue // known-to-be-synchronous operations from an asynchronous context. conn.run_with_shared_connection(|conn| { - Conn::TransactionManager::begin_transaction(conn).map_err(E::from) + Conn::TransactionManager::begin_transaction(conn) + .map_err(|err| E::from(RunError::User(err))) }) - .await?; + .await + .map_err(|err| match err { + RunError::User(err) => err, + RunError::RuntimeShutdown => RunError::RuntimeShutdown.into(), + })?; // TODO: The ideal interface would pass the "async_conn" object to the // underlying function "f" by reference. @@ -371,53 +390,86 @@ where let async_conn = Connection(Self::as_async_conn(&conn).0.clone()); match f(async_conn).await { Ok(value) => { - conn.run_with_shared_connection(|conn| { - Conn::TransactionManager::commit_transaction(conn).map_err(E::from) - }) - .await?; - Ok(value) + match conn + .run_with_shared_connection(|conn| { + Conn::TransactionManager::commit_transaction(conn) + .map_err(|err| E::from(RunError::User(err))) + }) + .await + { + Ok(()) => Ok(value), + // XXX: we should try to roll this back + Err(RunError::User(err)) => Err(err), + Err(RunError::RuntimeShutdown) => Err(RunError::RuntimeShutdown.into()), + } } Err(user_error) => { match conn .run_with_shared_connection(|conn| { - Conn::TransactionManager::rollback_transaction(conn).map_err(E::from) + Conn::TransactionManager::rollback_transaction(conn) + .map_err(|err| E::from(RunError::User(err))) }) .await { Ok(()) => Err(user_error), - Err(err) => Err(err), + Err(RunError::User(err)) => Err(err), + Err(RunError::RuntimeShutdown) => Err(RunError::RuntimeShutdown.into()), } } } } } +fn handle_spawn_blocking_error( + result: Result, JoinError>, +) -> Result> { + match result { + Ok(Ok(v)) => Ok(v), + Ok(Err(err)) => Err(RunError::User(err)), + Err(err) => { + if err.is_cancelled() { + // The only way a spawn_blocking task can be marked cancelled + // is if the runtime started shutting down _before_ + // spawn_blocking was called. + Err(RunError::RuntimeShutdown) + } else if err.is_panic() { + // Propagate panics. + std::panic::panic_any(err.into_panic()); + } else { + // Not possible to reach this as of Tokio 1.40, but maybe in + // future versions. + panic!("unexpected JoinError: {:?}", err); + } + } + } +} + /// An async variant of [`diesel::query_dsl::RunQueryDsl`]. #[async_trait] pub trait AsyncRunQueryDsl where Conn: 'static + DieselConnection, { - async fn execute_async(self, asc: &AsyncConn) -> Result + async fn execute_async(self, asc: &AsyncConn) -> Result> where Self: ExecuteDsl; - async fn load_async(self, asc: &AsyncConn) -> Result, DieselError> + async fn load_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn get_result_async(self, asc: &AsyncConn) -> Result + async fn get_result_async(self, asc: &AsyncConn) -> Result> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn get_results_async(self, asc: &AsyncConn) -> Result, DieselError> + async fn get_results_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn first_async(self, asc: &AsyncConn) -> Result + async fn first_async(self, asc: &AsyncConn) -> Result> where U: Send + 'static, Self: LimitDsl, @@ -431,14 +483,14 @@ where Conn: 'static + DieselConnection, AsyncConn: Send + Sync + AsyncConnection, { - async fn execute_async(self, asc: &AsyncConn) -> Result + async fn execute_async(self, asc: &AsyncConn) -> Result> where Self: ExecuteDsl, { asc.run(|conn| self.execute(conn)).await } - async fn load_async(self, asc: &AsyncConn) -> Result, DieselError> + async fn load_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, @@ -446,7 +498,7 @@ where asc.run(|conn| self.load(conn)).await } - async fn get_result_async(self, asc: &AsyncConn) -> Result + async fn get_result_async(self, asc: &AsyncConn) -> Result> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, @@ -454,7 +506,7 @@ where asc.run(|conn| self.get_result(conn)).await } - async fn get_results_async(self, asc: &AsyncConn) -> Result, DieselError> + async fn get_results_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, @@ -462,7 +514,7 @@ where asc.run(|conn| self.get_results(conn)).await } - async fn first_async(self, asc: &AsyncConn) -> Result + async fn first_async(self, asc: &AsyncConn) -> Result> where U: Send + 'static, Self: LimitDsl, @@ -477,7 +529,10 @@ pub trait AsyncSaveChangesDsl where Conn: 'static + DieselConnection, { - async fn save_changes_async(self, asc: &AsyncConn) -> Result + async fn save_changes_async( + self, + asc: &AsyncConn, + ) -> Result> where Self: Sized, Conn: diesel::query_dsl::UpdateAndFetchResults, @@ -491,7 +546,10 @@ where Conn: 'static + DieselConnection, AsyncConn: Send + Sync + AsyncConnection, { - async fn save_changes_async(self, asc: &AsyncConn) -> Result + async fn save_changes_async( + self, + asc: &AsyncConn, + ) -> Result> where Conn: diesel::query_dsl::UpdateAndFetchResults, Output: Send + 'static, diff --git a/src/error.rs b/src/error.rs index 333c13e..2aa4ef3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -20,6 +20,18 @@ pub enum ConnectionError { #[error("Failed to issue a query: {0}")] Query(#[from] DieselError), + + #[error("runtime shutting down")] + RuntimeShutdown, +} + +impl From> for ConnectionError { + fn from(error: RunError) -> Self { + match error { + RunError::User(e) => ConnectionError::Query(e), + RunError::RuntimeShutdown => ConnectionError::RuntimeShutdown, + } + } } /// Syntactic sugar around a Result returning an [`PoolError`]. @@ -44,6 +56,40 @@ impl OptionalExtension for Result { } } +impl OptionalExtension for Result> { + fn optional(self) -> Result, ConnectionError> { + let self_as_query_result: diesel::QueryResult = match self { + Ok(value) => Ok(value), + Err(RunError::User(error_kind)) => Err(error_kind), + Err(RunError::RuntimeShutdown) => return Err(ConnectionError::RuntimeShutdown), + }; + + self_as_query_result + .optional() + .map_err(ConnectionError::Query) + } +} + +/// Errors encountered while running a function on a connection pool. +#[derive(Error, Debug, Clone, Copy, Eq, PartialEq)] +pub enum RunError { + #[error("error in user code")] + User(#[from] E), + + #[error("runtime shutting down")] + RuntimeShutdown, +} + +impl RunError> { + /// Flatten a nested `RunError`. + pub fn flatten(self) -> RunError { + match self { + RunError::User(inner) => inner, + RunError::RuntimeShutdown => RunError::RuntimeShutdown, + } + } +} + /// Describes an error performing an operation from a connection pool. /// /// This is a superset of [`ConnectionError`] which also may diff --git a/src/lib.rs b/src/lib.rs index e4fbecd..e9aef5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,4 +16,6 @@ pub use async_traits::{ }; pub use connection::Connection; pub use connection_manager::ConnectionManager; -pub use error::{ConnectionError, ConnectionResult, OptionalExtension, PoolError, PoolResult}; +pub use error::{ + ConnectionError, ConnectionResult, OptionalExtension, PoolError, PoolResult, RunError, +}; diff --git a/tests/test.rs b/tests/test.rs index 31fea49..bcbd507 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -4,9 +4,9 @@ use async_bb8_diesel::{ AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, AsyncSimpleConnection, ConnectionError, + OptionalExtension, RunError, }; use crdb_harness::{CockroachInstance, CockroachStarterBuilder}; -use diesel::OptionalExtension; use diesel::{pg::PgConnection, prelude::*}; table! { @@ -208,13 +208,13 @@ async fn test_transaction_automatic_retry_explicit_rollback() { if *count < 2 { eprintln!("test: Manually restarting txn"); - return Err::<(), _>(diesel::result::Error::DatabaseError( + return Err::<(), _>(RunError::User(diesel::result::Error::DatabaseError( diesel::result::DatabaseErrorKind::SerializationFailure, Box::new("restart transaction".to_string()), - )); + ))); } eprintln!("test: Manually rolling back txn"); - return Err(diesel::result::Error::RollbackTransaction); + return Err(RunError::User(diesel::result::Error::RollbackTransaction)); } }, || async { @@ -225,7 +225,10 @@ async fn test_transaction_automatic_retry_explicit_rollback() { .await .expect_err("Transaction should have failed"); - assert_eq!(err, diesel::result::Error::RollbackTransaction); + assert_eq!( + err, + RunError::User(diesel::result::Error::RollbackTransaction) + ); assert_eq!(conn.transaction_depth().await.unwrap(), 0); // The transaction closure should have been attempted twice, but @@ -317,12 +320,12 @@ async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors() assert_eq!(conn.transaction_depth().await.unwrap(), 0); assert_eq!( conn.transaction_async_with_retry( - |_| async { Err::<(), _>(diesel::result::Error::NotFound) }, + |_| async { Err::<(), _>(RunError::User(diesel::result::Error::NotFound)) }, || async { panic!("Should not attempt to retry this operation") } ) .await .expect_err("Transaction should have failed"), - diesel::result::Error::NotFound, + RunError::User(diesel::result::Error::NotFound), ); assert_eq!(conn.transaction_depth().await.unwrap(), 0); @@ -361,7 +364,10 @@ async fn test_transaction_automatic_retry_nested_transactions_fail() { ) .await .expect_err("Nested transaction should have failed"); - assert_eq!(err, diesel::result::Error::AlreadyInTransaction); + assert_eq!( + err, + RunError::User(diesel::result::Error::AlreadyInTransaction) + ); // We still want to show that control exists within the outer // transaction, so we explicitly return here. @@ -395,9 +401,9 @@ async fn test_transaction_custom_error() { Other, } - impl From for MyError { - fn from(error: diesel::result::Error) -> Self { - MyError::Db(ConnectionError::Query(error)) + impl From> for MyError { + fn from(error: RunError) -> Self { + MyError::Db(ConnectionError::from(error)) } } From 01be5c0c98eddf0041d04ade5a5ba685a570336a Mon Sep 17 00:00:00 2001 From: Rain Date: Mon, 14 Oct 2024 20:05:39 -0700 Subject: [PATCH 2/2] please work Created using spr 1.3.6-beta.1 --- src/async_traits.rs | 146 ++++++++++++++++++-------------------------- src/error.rs | 31 ++++------ tests/test.rs | 26 ++++---- 3 files changed, 84 insertions(+), 119 deletions(-) diff --git a/src/async_traits.rs b/src/async_traits.rs index 49e2c26..3d75775 100644 --- a/src/async_traits.rs +++ b/src/async_traits.rs @@ -48,13 +48,13 @@ where Conn: 'static + DieselConnection + R2D2Connection, Self: Send + Sized + 'static, { - async fn ping_async(&mut self) -> Result<(), RunError> { + async fn ping_async(&mut self) -> Result<(), RunError> { self.as_async_conn().run(|conn| conn.ping()).await } async fn is_broken_async(&mut self) -> bool { self.as_async_conn() - .run(|conn| Ok::(conn.is_broken())) + .run(|conn| Ok::(conn.is_broken())) .await .unwrap() } @@ -75,42 +75,36 @@ where fn as_async_conn(&self) -> &Connection; /// Runs the function `f` in an context where blocking is safe. - async fn run(&self, f: Func) -> Result> + async fn run(&self, f: Func) -> Result where R: Send + 'static, - E: Send + 'static, - Func: FnOnce(&mut Conn) -> Result + Send + 'static, + Func: FnOnce(&mut Conn) -> Result + Send + 'static, { let connection = self.get_owned_connection(); connection.run_with_connection(f).await } #[doc(hidden)] - async fn run_with_connection(self, f: Func) -> Result> + async fn run_with_connection(self, f: Func) -> Result where R: Send + 'static, - E: Send + 'static, - Func: FnOnce(&mut Conn) -> Result + Send + 'static, + Func: FnOnce(&mut Conn) -> Result + Send + 'static, { handle_spawn_blocking_error(spawn_blocking(move || f(&mut *self.as_sync_conn())).await) } #[doc(hidden)] - async fn run_with_shared_connection( - self: &Arc, - f: Func, - ) -> Result> + async fn run_with_shared_connection(self: &Arc, f: Func) -> Result where R: Send + 'static, - E: Send + 'static, - Func: FnOnce(&mut Conn) -> Result + Send + 'static, + Func: FnOnce(&mut Conn) -> Result + Send + 'static, { let conn = self.clone(); handle_spawn_blocking_error(spawn_blocking(move || f(&mut *conn.as_sync_conn())).await) } #[doc(hidden)] - async fn transaction_depth(&self) -> Result> { + async fn transaction_depth(&self) -> Result { let conn = self.get_owned_connection(); Self::run_with_connection(conn, |conn| { @@ -130,9 +124,9 @@ where // This method is a wrapper around that call, with validation that // we're actually issuing the BEGIN statement here. #[doc(hidden)] - async fn start_transaction(self: &Arc) -> Result<(), RunError> { + async fn start_transaction(self: &Arc) -> Result<(), RunError> { if self.transaction_depth().await? != 0 { - return Err(RunError::User(DieselError::AlreadyInTransaction)); + return Err(RunError::DieselError(DieselError::AlreadyInTransaction)); } self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) .await?; @@ -145,11 +139,11 @@ where // This method is a wrapper around that call, with validation that // we're actually issuing our first SAVEPOINT here. #[doc(hidden)] - async fn add_retry_savepoint(self: &Arc) -> Result<(), RunError> { + async fn add_retry_savepoint(self: &Arc) -> Result<(), RunError> { match self.transaction_depth().await? { - 0 => return Err(RunError::User(DieselError::NotInTransaction)), + 0 => return Err(RunError::DieselError(DieselError::NotInTransaction)), 1 => (), - _ => return Err(RunError::User(DieselError::AlreadyInTransaction)), + _ => return Err(RunError::DieselError(DieselError::AlreadyInTransaction)), }; self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) @@ -158,14 +152,14 @@ where } #[doc(hidden)] - async fn commit_transaction(self: &Arc) -> Result<(), RunError> { + async fn commit_transaction(self: &Arc) -> Result<(), RunError> { self.run_with_shared_connection(|conn| Conn::TransactionManager::commit_transaction(conn)) .await?; Ok(()) } #[doc(hidden)] - async fn rollback_transaction(self: &Arc) -> Result<(), RunError> { + async fn rollback_transaction(self: &Arc) -> Result<(), RunError> { self.run_with_shared_connection(|conn| { Conn::TransactionManager::rollback_transaction(conn) }) @@ -184,10 +178,10 @@ where &'a self, f: Func, retry: RetryFunc, - ) -> Result> + ) -> Result where R: Any + Send + 'static, - Fut: FutureExt>> + Send, + Fut: FutureExt> + Send, Func: (Fn(Connection) -> Fut) + Send + Sync, RetryFut: FutureExt + Send, RetryFunc: Fn() -> RetryFut + Send + Sync, @@ -220,13 +214,11 @@ where #[cfg(feature = "cockroach")] async fn transaction_async_with_retry_inner( &self, - f: &(dyn Fn( - Connection, - ) -> BoxFuture<'_, Result, RunError>> + f: &(dyn Fn(Connection) -> BoxFuture<'_, Result, RunError>> + Send + Sync), retry: &(dyn Fn() -> BoxFuture<'_, bool> + Send + Sync), - ) -> Result, RunError> { + ) -> Result, RunError> { // Check out a connection once, and use it for the duration of the // operation. let conn = Arc::new(self.get_owned_connection()); @@ -264,7 +256,7 @@ where // We're still in the transaction, but we at least // tried to ROLLBACK to our savepoint. let retried = match &err { - RunError::User(err) => retryable_error(err) && retry().await, + RunError::DieselError(err) => retryable_error(err) && retry().await, RunError::RuntimeShutdown => false, }; if retried { @@ -282,7 +274,7 @@ where Self::commit_transaction(&conn).await?; return Ok(value); } - Err(RunError::User(user_error)) => { + Err(RunError::DieselError(user_error)) => { // The user-level operation failed: ROLLBACK to the retry // savepoint. if let Err(first_rollback_err) = Self::rollback_transaction(&conn).await { @@ -302,7 +294,7 @@ where // If we aren't retrying, ROLLBACK the BEGIN statement too. return match Self::rollback_transaction(&conn).await { - Ok(()) => Err(RunError::User(user_error)), + Ok(()) => Err(RunError::DieselError(user_error)), Err(err) => Err(err), }; } @@ -321,7 +313,7 @@ where async fn transaction_async(&'a self, f: Func) -> Result where R: Send + 'static, - E: From> + Send + 'static, + E: From + Send + 'static, Fut: Future> + Send, Func: FnOnce(Connection) -> Fut + Send, { @@ -354,7 +346,7 @@ where >, ) -> Result, E> where - E: From> + Send + 'static, + E: From + Send + 'static, { // Check out a connection once, and use it for the duration of the // operation. @@ -365,15 +357,8 @@ where // // However, it modifies all callsites to instead issue // known-to-be-synchronous operations from an asynchronous context. - conn.run_with_shared_connection(|conn| { - Conn::TransactionManager::begin_transaction(conn) - .map_err(|err| E::from(RunError::User(err))) - }) - .await - .map_err(|err| match err { - RunError::User(err) => err, - RunError::RuntimeShutdown => RunError::RuntimeShutdown.into(), - })?; + conn.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) + .await?; // TODO: The ideal interface would pass the "async_conn" object to the // underlying function "f" by reference. @@ -390,42 +375,29 @@ where let async_conn = Connection(Self::as_async_conn(&conn).0.clone()); match f(async_conn).await { Ok(value) => { - match conn - .run_with_shared_connection(|conn| { - Conn::TransactionManager::commit_transaction(conn) - .map_err(|err| E::from(RunError::User(err))) - }) - .await - { - Ok(()) => Ok(value), - // XXX: we should try to roll this back - Err(RunError::User(err)) => Err(err), - Err(RunError::RuntimeShutdown) => Err(RunError::RuntimeShutdown.into()), - } + conn.run_with_shared_connection(|conn| { + Conn::TransactionManager::commit_transaction(conn) + }) + .await?; + Ok(value) } Err(user_error) => { - match conn - .run_with_shared_connection(|conn| { - Conn::TransactionManager::rollback_transaction(conn) - .map_err(|err| E::from(RunError::User(err))) - }) - .await - { - Ok(()) => Err(user_error), - Err(RunError::User(err)) => Err(err), - Err(RunError::RuntimeShutdown) => Err(RunError::RuntimeShutdown.into()), - } + conn.run_with_shared_connection(|conn| { + Conn::TransactionManager::rollback_transaction(conn) + }) + .await?; + Err(user_error) } } } } -fn handle_spawn_blocking_error( - result: Result, JoinError>, -) -> Result> { +fn handle_spawn_blocking_error( + result: Result, JoinError>, +) -> Result { match result { Ok(Ok(v)) => Ok(v), - Ok(Err(err)) => Err(RunError::User(err)), + Ok(Err(err)) => Err(RunError::DieselError(err)), Err(err) => { if err.is_cancelled() { // The only way a spawn_blocking task can be marked cancelled @@ -438,7 +410,11 @@ fn handle_spawn_blocking_error( } else { // Not possible to reach this as of Tokio 1.40, but maybe in // future versions. - panic!("unexpected JoinError: {:?}", err); + panic!( + "unexpected JoinError, did a new version of Tokio add \ + a source other than panics and cancellations? {:?}", + err + ); } } } @@ -450,26 +426,26 @@ pub trait AsyncRunQueryDsl where Conn: 'static + DieselConnection, { - async fn execute_async(self, asc: &AsyncConn) -> Result> + async fn execute_async(self, asc: &AsyncConn) -> Result where Self: ExecuteDsl; - async fn load_async(self, asc: &AsyncConn) -> Result, RunError> + async fn load_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn get_result_async(self, asc: &AsyncConn) -> Result> + async fn get_result_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn get_results_async(self, asc: &AsyncConn) -> Result, RunError> + async fn get_results_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>; - async fn first_async(self, asc: &AsyncConn) -> Result> + async fn first_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LimitDsl, @@ -483,14 +459,14 @@ where Conn: 'static + DieselConnection, AsyncConn: Send + Sync + AsyncConnection, { - async fn execute_async(self, asc: &AsyncConn) -> Result> + async fn execute_async(self, asc: &AsyncConn) -> Result where Self: ExecuteDsl, { asc.run(|conn| self.execute(conn)).await } - async fn load_async(self, asc: &AsyncConn) -> Result, RunError> + async fn load_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, @@ -498,7 +474,7 @@ where asc.run(|conn| self.load(conn)).await } - async fn get_result_async(self, asc: &AsyncConn) -> Result> + async fn get_result_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, @@ -506,7 +482,7 @@ where asc.run(|conn| self.get_result(conn)).await } - async fn get_results_async(self, asc: &AsyncConn) -> Result, RunError> + async fn get_results_async(self, asc: &AsyncConn) -> Result, RunError> where U: Send + 'static, Self: LoadQuery<'static, Conn, U>, @@ -514,7 +490,7 @@ where asc.run(|conn| self.get_results(conn)).await } - async fn first_async(self, asc: &AsyncConn) -> Result> + async fn first_async(self, asc: &AsyncConn) -> Result where U: Send + 'static, Self: LimitDsl, @@ -529,10 +505,7 @@ pub trait AsyncSaveChangesDsl where Conn: 'static + DieselConnection, { - async fn save_changes_async( - self, - asc: &AsyncConn, - ) -> Result> + async fn save_changes_async(self, asc: &AsyncConn) -> Result where Self: Sized, Conn: diesel::query_dsl::UpdateAndFetchResults, @@ -546,10 +519,7 @@ where Conn: 'static + DieselConnection, AsyncConn: Send + Sync + AsyncConnection, { - async fn save_changes_async( - self, - asc: &AsyncConn, - ) -> Result> + async fn save_changes_async(self, asc: &AsyncConn) -> Result where Conn: diesel::query_dsl::UpdateAndFetchResults, Output: Send + 'static, diff --git a/src/error.rs b/src/error.rs index 2aa4ef3..184df43 100644 --- a/src/error.rs +++ b/src/error.rs @@ -25,10 +25,10 @@ pub enum ConnectionError { RuntimeShutdown, } -impl From> for ConnectionError { - fn from(error: RunError) -> Self { +impl From for ConnectionError { + fn from(error: RunError) -> Self { match error { - RunError::User(e) => ConnectionError::Query(e), + RunError::DieselError(e) => ConnectionError::Query(e), RunError::RuntimeShutdown => ConnectionError::RuntimeShutdown, } } @@ -56,11 +56,11 @@ impl OptionalExtension for Result { } } -impl OptionalExtension for Result> { +impl OptionalExtension for Result { fn optional(self) -> Result, ConnectionError> { let self_as_query_result: diesel::QueryResult = match self { Ok(value) => Ok(value), - Err(RunError::User(error_kind)) => Err(error_kind), + Err(RunError::DieselError(error_kind)) => Err(error_kind), Err(RunError::RuntimeShutdown) => return Err(ConnectionError::RuntimeShutdown), }; @@ -70,26 +70,17 @@ impl OptionalExtension for Result> { } } -/// Errors encountered while running a function on a connection pool. -#[derive(Error, Debug, Clone, Copy, Eq, PartialEq)] -pub enum RunError { - #[error("error in user code")] - User(#[from] E), +/// An error encountered while running a function on a connection pool. +#[derive(Error, Debug, PartialEq)] +pub enum RunError { + /// There was a Diesel error running the query. + #[error(transparent)] + DieselError(#[from] DieselError), #[error("runtime shutting down")] RuntimeShutdown, } -impl RunError> { - /// Flatten a nested `RunError`. - pub fn flatten(self) -> RunError { - match self { - RunError::User(inner) => inner, - RunError::RuntimeShutdown => RunError::RuntimeShutdown, - } - } -} - /// Describes an error performing an operation from a connection pool. /// /// This is a superset of [`ConnectionError`] which also may diff --git a/tests/test.rs b/tests/test.rs index bcbd507..21ff106 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -208,13 +208,17 @@ async fn test_transaction_automatic_retry_explicit_rollback() { if *count < 2 { eprintln!("test: Manually restarting txn"); - return Err::<(), _>(RunError::User(diesel::result::Error::DatabaseError( - diesel::result::DatabaseErrorKind::SerializationFailure, - Box::new("restart transaction".to_string()), - ))); + return Err::<(), _>(RunError::DieselError( + diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::SerializationFailure, + Box::new("restart transaction".to_string()), + ), + )); } eprintln!("test: Manually rolling back txn"); - return Err(RunError::User(diesel::result::Error::RollbackTransaction)); + return Err(RunError::DieselError( + diesel::result::Error::RollbackTransaction, + )); } }, || async { @@ -227,7 +231,7 @@ async fn test_transaction_automatic_retry_explicit_rollback() { assert_eq!( err, - RunError::User(diesel::result::Error::RollbackTransaction) + RunError::DieselError(diesel::result::Error::RollbackTransaction) ); assert_eq!(conn.transaction_depth().await.unwrap(), 0); @@ -320,12 +324,12 @@ async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors() assert_eq!(conn.transaction_depth().await.unwrap(), 0); assert_eq!( conn.transaction_async_with_retry( - |_| async { Err::<(), _>(RunError::User(diesel::result::Error::NotFound)) }, + |_| async { Err::<(), _>(RunError::DieselError(diesel::result::Error::NotFound)) }, || async { panic!("Should not attempt to retry this operation") } ) .await .expect_err("Transaction should have failed"), - RunError::User(diesel::result::Error::NotFound), + RunError::DieselError(diesel::result::Error::NotFound), ); assert_eq!(conn.transaction_depth().await.unwrap(), 0); @@ -366,7 +370,7 @@ async fn test_transaction_automatic_retry_nested_transactions_fail() { .expect_err("Nested transaction should have failed"); assert_eq!( err, - RunError::User(diesel::result::Error::AlreadyInTransaction) + RunError::DieselError(diesel::result::Error::AlreadyInTransaction) ); // We still want to show that control exists within the outer @@ -401,8 +405,8 @@ async fn test_transaction_custom_error() { Other, } - impl From> for MyError { - fn from(error: RunError) -> Self { + impl From for MyError { + fn from(error: RunError) -> Self { MyError::Db(ConnectionError::from(error)) } }