-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathconnection.rs
98 lines (84 loc) · 2.67 KB
/
connection.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
//! An async wrapper around a [`diesel::Connection`].
use crate::async_traits::AsyncConnection;
use async_trait::async_trait;
use diesel::r2d2::R2D2Connection;
use diesel::result::Error as DieselError;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, MutexGuard};
use tokio::task;
/// An async-safe analogue of any connection that implements
/// [`diesel::Connection`].
///
/// These connections are created by [`crate::ConnectionManager`].
///
/// A `Connection` is a cheap, reference-counted handle: the wrapped
/// connection lives behind a [`Mutex`] inside an [`Arc`], so every handle
/// produced by the crate-internal `clone` shares the same connection and the
/// same "broken" flag.
///
/// Async operations such as `batch_execute_async` run the underlying blocking
/// diesel call on tokio's blocking thread pool via
/// [`tokio::task::spawn_blocking`], so they do not stall the async executor.
/// The synchronous accessors (`inner` / `as_sync_conn`) lock the mutex
/// directly and therefore *do* block the calling thread.
pub struct Connection<C>(Arc<ConnectionInner<C>>);

/// State shared by every [`Connection`] handle that wraps the same connection.
pub struct ConnectionInner<C> {
    /// The wrapped connection; the mutex serializes all access to it.
    pub(crate) inner: Mutex<C>,
    /// Set by [`Connection::mark_broken`]; starts out `false`.
    pub(crate) broken: AtomicBool,
}

impl<C> Connection<C> {
    /// Wraps `c` in a fresh handle whose "broken" flag starts cleared.
    pub fn new(c: C) -> Self {
        let shared = ConnectionInner {
            broken: AtomicBool::new(false),
            inner: Mutex::new(c),
        };
        Connection(Arc::new(shared))
    }

    /// Returns another handle to the *same* underlying connection (only the
    /// reference count is bumped; nothing is reconnected or copied).
    ///
    /// NOTE(review): kept as a crate-private inherent method rather than a
    /// `Clone` impl, presumably so code outside this crate cannot mint extra
    /// handles outside the pool — confirm before changing.
    pub(crate) fn clone(&self) -> Self {
        Connection(Arc::clone(&self.0))
    }

    /// Flags the connection as broken for every handle sharing it.
    ///
    /// Afterwards `batch_execute_async` refuses to run queries and returns
    /// `BrokenTransactionManager` instead.
    pub(crate) fn mark_broken(&self) {
        let ConnectionInner { broken, .. } = &*self.0;
        broken.store(true, Ordering::SeqCst);
    }

    /// Locks and returns the underlying connection.
    ///
    /// This is a blocking `std::sync::Mutex`: the calling thread waits until
    /// the lock is free, so avoid calling it from an asynchronous context.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned, i.e. a previous holder panicked while
    /// it held the lock.
    pub(crate) fn inner(&self) -> MutexGuard<'_, C> {
        let ConnectionInner { inner: mutex, .. } = &*self.0;
        mutex.lock().unwrap()
    }
}
#[async_trait]
impl<Conn> crate::AsyncSimpleConnection<Conn> for Connection<Conn>
where
    Conn: 'static + R2D2Connection,
{
    /// Runs `query` (which may contain several SQL statements) on the
    /// underlying connection without blocking the async executor.
    ///
    /// The blocking diesel call is executed on tokio's blocking thread pool via
    /// [`tokio::task::spawn_blocking`].
    ///
    /// # Errors
    ///
    /// Returns [`DieselError::BrokenTransactionManager`] without running the
    /// query if the connection has been marked broken. Otherwise, returns
    /// whatever error `batch_execute` reports.
    ///
    /// # Panics
    ///
    /// If the blocking task panics, the *original* panic payload is re-raised
    /// on the calling task. Also panics if the blocking task was cancelled
    /// rather than completing or panicking.
    #[inline]
    async fn batch_execute_async(&self, query: &str) -> Result<(), diesel::result::Error> {
        if self.is_broken_from_txn() {
            return Err(DieselError::BrokenTransactionManager);
        }
        // `spawn_blocking` needs a `'static` closure, so move an owned handle
        // and an owned copy of the query into it.
        let diesel_conn = self.clone();
        let query = query.to_string();
        match task::spawn_blocking(move || diesel_conn.inner().batch_execute(&query)).await {
            Ok(result) => result,
            Err(join_err) => match join_err.try_into_panic() {
                // Propagate the panic with its original payload instead of
                // wrapping it in a fresh "called `Result::unwrap()`" panic.
                Ok(payload) => std::panic::resume_unwind(payload),
                Err(join_err) => panic!("blocking batch_execute task failed: {}", join_err),
            },
        }
    }
}
/// Opts [`Connection`] into `crate::AsyncR2D2Connection`.
///
/// The impl body is empty, so every item comes from the trait's own
/// definition; this impl only declares that the trait applies.
#[async_trait]
impl<Conn> crate::AsyncR2D2Connection<Conn> for Connection<Conn>
where
    Conn: 'static + R2D2Connection,
{
}
#[async_trait]
impl<Conn> crate::AsyncConnection<Conn> for Connection<Conn>
where
Conn: 'static + R2D2Connection,
Connection<Conn>: crate::AsyncSimpleConnection<Conn>,
{
fn get_owned_connection(&self) -> Self {
self.clone()
}
// Accesses the connection synchronously, protected by a mutex.
//
// Avoid calling from asynchronous contexts.
fn as_sync_conn(&self) -> MutexGuard<'_, Conn> {
self.inner()
}
fn as_async_conn(&self) -> &Connection<Conn> {
self
}
fn is_broken_from_txn(&self) -> bool {
self.0.broken.load(Ordering::SeqCst)
}
}