Skip to content

Commit 1446f7e

Browse files
authored
Flatten error handling, reducing generics (#54)
* Simplify a lot of the error propagation * Flatten error types
1 parent da04c08 commit 1446f7e

File tree

4 files changed

+47
-58
lines changed

4 files changed

+47
-58
lines changed

examples/customize_connection.rs

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ impl bb8::CustomizeConnection<DieselPgConn, ConnectionError> for ConnectionCusto
1515
connection
1616
.batch_execute_async("please execute some raw sql for me")
1717
.await
18+
.map_err(ConnectionError::from)
1819
}
1920
}
2021

examples/usage.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
use async_bb8_diesel::{
2-
AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError, OptionalExtension,
3-
};
1+
use async_bb8_diesel::{AsyncConnection, AsyncRunQueryDsl, AsyncSaveChangesDsl, ConnectionError};
2+
use diesel::OptionalExtension;
43
use diesel::{pg::PgConnection, prelude::*};
54

65
table! {

src/async_traits.rs

+38-47
Original file line numberDiff line numberDiff line change
@@ -18,42 +18,37 @@ use tokio::task::spawn_blocking;
1818

1919
/// An async variant of [`diesel::connection::SimpleConnection`].
2020
#[async_trait]
21-
pub trait AsyncSimpleConnection<Conn, ConnErr>
21+
pub trait AsyncSimpleConnection<Conn>
2222
where
2323
Conn: 'static + SimpleConnection,
2424
{
25-
async fn batch_execute_async(&self, query: &str) -> Result<(), ConnErr>;
25+
async fn batch_execute_async(&self, query: &str) -> Result<(), DieselError>;
2626
}
2727

2828
/// An async variant of [`diesel::connection::Connection`].
2929
#[async_trait]
30-
pub trait AsyncConnection<Conn, ConnErr>: AsyncSimpleConnection<Conn, ConnErr>
30+
pub trait AsyncConnection<Conn>: AsyncSimpleConnection<Conn>
3131
where
3232
Conn: 'static + DieselConnection,
33-
ConnErr: From<DieselError> + Send + 'static,
3433
Self: Send,
3534
{
3635
type OwnedConnection: Sync + Send + 'static;
3736

3837
#[doc(hidden)]
39-
async fn get_owned_connection(&self) -> Result<Self::OwnedConnection, ConnErr>;
38+
async fn get_owned_connection(&self) -> Self::OwnedConnection;
4039
#[doc(hidden)]
4140
fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn>;
4241
#[doc(hidden)]
4342
fn as_async_conn(owned: &Self::OwnedConnection) -> &SingleConnection<Conn>;
4443

4544
/// Runs the function `f` in an context where blocking is safe.
46-
///
47-
/// Any error may be propagated through `f`, as long as that
48-
/// error type may be constructed from `ConnErr` (as that error
49-
/// type may also be generated).
5045
async fn run<R, E, Func>(&self, f: Func) -> Result<R, E>
5146
where
5247
R: Send + 'static,
53-
E: From<ConnErr> + Send + 'static,
48+
E: Send + 'static,
5449
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
5550
{
56-
let connection = self.get_owned_connection().await?;
51+
let connection = self.get_owned_connection().await;
5752
Self::run_with_connection(connection, f).await
5853
}
5954

@@ -64,7 +59,7 @@ where
6459
) -> Result<R, E>
6560
where
6661
R: Send + 'static,
67-
E: From<ConnErr> + Send + 'static,
62+
E: Send + 'static,
6863
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
6964
{
7065
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
@@ -79,7 +74,7 @@ where
7974
) -> Result<R, E>
8075
where
8176
R: Send + 'static,
82-
E: From<ConnErr> + Send + 'static,
77+
E: Send + 'static,
8378
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
8479
{
8580
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
@@ -90,7 +85,7 @@ where
9085
async fn transaction<R, E, Func>(&self, f: Func) -> Result<R, E>
9186
where
9287
R: Send + 'static,
93-
E: From<DieselError> + From<ConnErr> + Send + 'static,
88+
E: From<DieselError> + Send + 'static,
9489
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
9590
{
9691
self.run(|conn| conn.transaction(|c| f(c))).await
@@ -99,21 +94,21 @@ where
9994
async fn transaction_async<R, E, Func, Fut, 'a>(&'a self, f: Func) -> Result<R, E>
10095
where
10196
R: Send + 'static,
102-
E: From<DieselError> + From<ConnErr> + Send,
97+
E: From<DieselError> + Send + 'static,
10398
Fut: Future<Output = Result<R, E>> + Send,
10499
Func: FnOnce(SingleConnection<Conn>) -> Fut + Send,
105100
{
106101
// Check out a connection once, and use it for the duration of the
107102
// operation.
108-
let conn = Arc::new(self.get_owned_connection().await?);
103+
let conn = Arc::new(self.get_owned_connection().await);
109104

110105
// This function mimics the implementation of:
111106
// https://docs.diesel.rs/master/diesel/connection/trait.TransactionManager.html#method.transaction
112107
//
113108
// However, it modifies all callsites to instead issue
114109
// known-to-be-synchronous operations from an asynchronous context.
115110
Self::run_with_shared_connection(conn.clone(), |conn| {
116-
Conn::TransactionManager::begin_transaction(conn).map_err(ConnErr::from)
111+
Conn::TransactionManager::begin_transaction(conn).map_err(E::from)
117112
})
118113
.await?;
119114

@@ -133,14 +128,14 @@ where
133128
match f(async_conn).await {
134129
Ok(value) => {
135130
Self::run_with_shared_connection(conn.clone(), |conn| {
136-
Conn::TransactionManager::commit_transaction(conn).map_err(ConnErr::from)
131+
Conn::TransactionManager::commit_transaction(conn).map_err(E::from)
137132
})
138133
.await?;
139134
Ok(value)
140135
}
141136
Err(user_error) => {
142137
match Self::run_with_shared_connection(conn.clone(), |conn| {
143-
Conn::TransactionManager::rollback_transaction(conn).map_err(ConnErr::from)
138+
Conn::TransactionManager::rollback_transaction(conn).map_err(E::from)
144139
})
145140
.await
146141
{
@@ -154,112 +149,108 @@ where
154149

155150
/// An async variant of [`diesel::query_dsl::RunQueryDsl`].
156151
#[async_trait]
157-
pub trait AsyncRunQueryDsl<Conn, AsyncConn, E>
152+
pub trait AsyncRunQueryDsl<Conn, AsyncConn>
158153
where
159154
Conn: 'static + DieselConnection,
160155
{
161-
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, E>
156+
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, DieselError>
162157
where
163158
Self: ExecuteDsl<Conn>;
164159

165-
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
160+
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
166161
where
167162
U: Send + 'static,
168163
Self: LoadQuery<'static, Conn, U>;
169164

170-
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, E>
165+
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
171166
where
172167
U: Send + 'static,
173168
Self: LoadQuery<'static, Conn, U>;
174169

175-
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
170+
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
176171
where
177172
U: Send + 'static,
178173
Self: LoadQuery<'static, Conn, U>;
179174

180-
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, E>
175+
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
181176
where
182177
U: Send + 'static,
183178
Self: LimitDsl,
184179
Limit<Self>: LoadQuery<'static, Conn, U>;
185180
}
186181

187182
#[async_trait]
188-
impl<T, AsyncConn, Conn, E> AsyncRunQueryDsl<Conn, AsyncConn, E> for T
183+
impl<T, AsyncConn, Conn> AsyncRunQueryDsl<Conn, AsyncConn> for T
189184
where
190185
T: 'static + Send + RunQueryDsl<Conn>,
191186
Conn: 'static + DieselConnection,
192-
AsyncConn: Send + Sync + AsyncConnection<Conn, E>,
193-
E: From<DieselError> + Send + 'static,
187+
AsyncConn: Send + Sync + AsyncConnection<Conn>,
194188
{
195-
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, E>
189+
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, DieselError>
196190
where
197191
Self: ExecuteDsl<Conn>,
198192
{
199-
asc.run(|conn| self.execute(conn).map_err(E::from)).await
193+
asc.run(|conn| self.execute(conn)).await
200194
}
201195

202-
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
196+
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
203197
where
204198
U: Send + 'static,
205199
Self: LoadQuery<'static, Conn, U>,
206200
{
207-
asc.run(|conn| self.load(conn).map_err(E::from)).await
201+
asc.run(|conn| self.load(conn)).await
208202
}
209203

210-
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, E>
204+
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
211205
where
212206
U: Send + 'static,
213207
Self: LoadQuery<'static, Conn, U>,
214208
{
215-
asc.run(|conn| self.get_result(conn).map_err(E::from)).await
209+
asc.run(|conn| self.get_result(conn)).await
216210
}
217211

218-
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, E>
212+
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, DieselError>
219213
where
220214
U: Send + 'static,
221215
Self: LoadQuery<'static, Conn, U>,
222216
{
223-
asc.run(|conn| self.get_results(conn).map_err(E::from))
224-
.await
217+
asc.run(|conn| self.get_results(conn)).await
225218
}
226219

227-
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, E>
220+
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, DieselError>
228221
where
229222
U: Send + 'static,
230223
Self: LimitDsl,
231224
Limit<Self>: LoadQuery<'static, Conn, U>,
232225
{
233-
asc.run(|conn| self.first(conn).map_err(E::from)).await
226+
asc.run(|conn| self.first(conn)).await
234227
}
235228
}
236229

237230
#[async_trait]
238-
pub trait AsyncSaveChangesDsl<Conn, AsyncConn, E>
231+
pub trait AsyncSaveChangesDsl<Conn, AsyncConn>
239232
where
240233
Conn: 'static + DieselConnection,
241234
{
242-
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, E>
235+
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, DieselError>
243236
where
244237
Self: Sized,
245238
Conn: diesel::query_dsl::UpdateAndFetchResults<Self, Output>,
246239
Output: Send + 'static;
247240
}
248241

249242
#[async_trait]
250-
impl<T, AsyncConn, Conn, E> AsyncSaveChangesDsl<Conn, AsyncConn, E> for T
243+
impl<T, AsyncConn, Conn> AsyncSaveChangesDsl<Conn, AsyncConn> for T
251244
where
252245
T: 'static + Send + Sync + diesel::SaveChangesDsl<Conn>,
253246
Conn: 'static + DieselConnection,
254-
AsyncConn: Send + Sync + AsyncConnection<Conn, E>,
255-
E: 'static + Send + From<DieselError>,
247+
AsyncConn: Send + Sync + AsyncConnection<Conn>,
256248
{
257-
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, E>
249+
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, DieselError>
258250
where
259251
Conn: diesel::query_dsl::UpdateAndFetchResults<Self, Output>,
260252
Output: Send + 'static,
261253
{
262-
asc.run(|conn| self.save_changes(conn).map_err(E::from))
263-
.await
254+
asc.run(|conn| self.save_changes(conn)).await
264255
}
265256
}

src/connection.rs

+6-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
//! An async wrapper around a [`diesel::Connection`].
22
3-
use crate::{ConnectionError, ConnectionResult};
43
use async_trait::async_trait;
54
use diesel::r2d2::R2D2Connection;
65
use std::sync::{Arc, Mutex, MutexGuard};
@@ -31,31 +30,30 @@ impl<C> Connection<C> {
3130
}
3231

3332
#[async_trait]
34-
impl<Conn> crate::AsyncSimpleConnection<Conn, ConnectionError> for Connection<Conn>
33+
impl<Conn> crate::AsyncSimpleConnection<Conn> for Connection<Conn>
3534
where
3635
Conn: 'static + R2D2Connection,
3736
{
3837
#[inline]
39-
async fn batch_execute_async(&self, query: &str) -> ConnectionResult<()> {
38+
async fn batch_execute_async(&self, query: &str) -> Result<(), diesel::result::Error> {
4039
let diesel_conn = Connection(self.0.clone());
4140
let query = query.to_string();
4241
task::spawn_blocking(move || diesel_conn.inner().batch_execute(&query))
4342
.await
4443
.unwrap() // Propagate panics
45-
.map_err(ConnectionError::from)
4644
}
4745
}
4846

4947
#[async_trait]
50-
impl<Conn> crate::AsyncConnection<Conn, ConnectionError> for Connection<Conn>
48+
impl<Conn> crate::AsyncConnection<Conn> for Connection<Conn>
5149
where
5250
Conn: 'static + R2D2Connection,
53-
Connection<Conn>: crate::AsyncSimpleConnection<Conn, ConnectionError>,
51+
Connection<Conn>: crate::AsyncSimpleConnection<Conn>,
5452
{
5553
type OwnedConnection = Connection<Conn>;
5654

57-
async fn get_owned_connection(&self) -> Result<Self::OwnedConnection, ConnectionError> {
58-
Ok(Connection(self.0.clone()))
55+
async fn get_owned_connection(&self) -> Self::OwnedConnection {
56+
Connection(self.0.clone())
5957
}
6058

6159
fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn> {

0 commit comments

Comments
 (0)