@@ -87,41 +87,47 @@ struct ConnectionData {
87
87
}
88
88
89
89
bool is_connecting () {
90
- return (state == EConnectionState::CONNECTING);
90
+ return (state. load () == EConnectionState::CONNECTING);
91
91
}
92
92
93
93
bool is_close_requested () {
94
94
return is_marked_close;
95
95
}
96
96
97
97
auto get_state () {
98
- return state;
98
+ return state. load () ;
99
99
}
100
100
101
101
lws* get_conn () {
102
102
return wsi;
103
103
}
104
104
105
- lws_context* get_ctx () {
106
- return lws_ctx.get ();
105
+ WebsocketTlsTPM* get_owner () {
106
+ return owner.load ();
107
+ }
108
+
109
+ void set_owner (WebsocketTlsTPM* o) {
110
+ owner = o;
107
111
}
108
112
109
113
public:
114
+ // This public block will only be used from client loop thread, no locking needed
110
115
// Openssl context, must be destroyed in this order
111
116
std::unique_ptr<SSL_CTX> sec_context;
112
117
std::unique_ptr<OSSL_LIB_CTX> sec_lib_context;
113
-
114
118
// libwebsockets state
115
119
std::unique_ptr<lws_context> lws_ctx;
116
- lws* wsi;
117
120
118
- WebsocketTlsTPM* owner ;
121
+ lws* wsi ;
119
122
120
123
private:
124
+ std::atomic<WebsocketTlsTPM*> owner;
125
+
121
126
std::thread::id lws_thread_id;
122
- bool is_running;
123
- bool is_marked_close;
124
- EConnectionState state;
127
+
128
+ std::atomic_bool is_running;
129
+ std::atomic_bool is_marked_close;
130
+ std::atomic<EConnectionState> state;
125
131
};
126
132
127
133
struct WebsocketMessage {
@@ -140,7 +146,7 @@ struct WebsocketMessage {
140
146
// just that these were sent to libwebsockets
141
147
size_t sent_bytes;
142
148
// If libwebsockets has sent all the bytes through the wire
143
- volatile bool message_sent;
149
+ std::atomic_bool message_sent;
144
150
};
145
151
146
152
WebsocketTlsTPM::WebsocketTlsTPM (const WebsocketConnectionOptions& connection_options,
@@ -189,7 +195,12 @@ static int callback_minimal(struct lws* wsi, enum lws_callback_reasons reason, v
189
195
// Get user safely, since on some callbacks (void *user) can be different than what we set
190
196
if (wsi != nullptr ) {
191
197
if (ConnectionData* data = reinterpret_cast <ConnectionData*>(lws_wsi_user (wsi))) {
192
- return data->owner ->process_callback (wsi, static_cast <int >(reason), user, in, len);
198
+ auto owner = data->get_owner ();
199
+ if (owner not_eq nullptr ) {
200
+ return data->get_owner ()->process_callback (wsi, static_cast <int >(reason), user, in, len);
201
+ } else {
202
+ EVLOG_error << " callback_minimal called, but data->owner is nullptr" ;
203
+ }
193
204
}
194
205
}
195
206
@@ -331,11 +342,14 @@ void WebsocketTlsTPM::recv_loop() {
331
342
332
343
while (false == data->is_interupted ()) {
333
344
// Process all messages
334
- while (false == recv_message_queue. empty () ) {
345
+ while (true ) {
335
346
std::string message{};
336
347
337
348
{
338
349
std::lock_guard lk (this ->recv_mutex );
350
+ if (recv_message_queue.empty ())
351
+ break ;
352
+
339
353
message = std::move (recv_message_queue.front ());
340
354
recv_message_queue.pop ();
341
355
}
@@ -459,9 +473,14 @@ void WebsocketTlsTPM::client_loop() {
459
473
460
474
while (n >= 0 && (false == data->is_interupted ())) {
461
475
// Set to -1 for continuous servicing, of required, not recommended
462
- n = lws_service (data->get_ctx (), 0 );
476
+ n = lws_service (data->lws_ctx . get (), 0 );
463
477
464
- if (false == message_queue.empty ()) {
478
+ bool message_queue_empty;
479
+ {
480
+ std::lock_guard<std::mutex> lock (this ->queue_mutex );
481
+ message_queue_empty = message_queue.empty ();
482
+ }
483
+ if (false == message_queue_empty) {
465
484
lws_callback_on_writable (data->get_conn ());
466
485
}
467
486
}
@@ -470,6 +489,7 @@ void WebsocketTlsTPM::client_loop() {
470
489
EVLOG_debug << " Exit client loop with ID: " << std::this_thread::get_id ();
471
490
}
472
491
492
+ // Will be called from external threads as well
473
493
bool WebsocketTlsTPM::connect () {
474
494
if (!this ->initialized ()) {
475
495
return false ;
@@ -485,7 +505,7 @@ bool WebsocketTlsTPM::connect() {
485
505
}
486
506
487
507
auto conn_data = new ConnectionData ();
488
- conn_data->owner = this ;
508
+ conn_data->set_owner ( this ) ;
489
509
490
510
this ->conn_data .reset (conn_data);
491
511
@@ -600,7 +620,8 @@ void WebsocketTlsTPM::on_conn_connected() {
600
620
this ->m_is_connected = true ;
601
621
this ->reconnecting = false ;
602
622
603
- this ->connected_callback (this ->connection_options .security_profile );
623
+ std::thread connected ([this ]() { this ->connected_callback (this ->connection_options .security_profile ); });
624
+ connected.detach ();
604
625
}
605
626
606
627
void WebsocketTlsTPM::on_conn_close () {
@@ -717,11 +738,14 @@ void WebsocketTlsTPM::on_writable() {
717
738
return ;
718
739
}
719
740
720
- while (false == message_queue. empty () ) {
741
+ while (true ) {
721
742
WebsocketMessage* message = nullptr ;
722
743
723
744
{
724
745
std::lock_guard<std::mutex> lock (this ->queue_mutex );
746
+ if (message_queue.empty ()) {
747
+ break ;
748
+ }
725
749
message = message_queue.front ().get ();
726
750
}
727
751
@@ -744,7 +768,7 @@ void WebsocketTlsTPM::on_writable() {
744
768
745
769
EVLOG_debug << " Notifying waiting thread!" ;
746
770
// Notify any waiting thread to check it's state
747
- msg_send_cv.notify_one ();
771
+ msg_send_cv.notify_all ();
748
772
} else {
749
773
EVLOG_debug << " Client writable, sending message part!" ;
750
774
@@ -766,16 +790,18 @@ void WebsocketTlsTPM::request_write() {
766
790
if (this ->m_is_connected ) {
767
791
if (auto * data = conn_data.get ()) {
768
792
if (data->get_conn ()) {
769
- // Notify waiting processing thread to wake up
770
- lws_cancel_service (data->get_ctx ());
793
+ // Notify waiting processing thread to wake up. According to docs it is ok to call from another
794
+ // thread.
795
+ lws_cancel_service (data->lws_ctx .get ());
771
796
}
772
797
}
773
798
} else {
774
799
EVLOG_warning << " Requested write with offline TLS websocket!" ;
775
800
}
776
801
}
777
802
778
- void WebsocketTlsTPM::poll_message (const std::shared_ptr<WebsocketMessage>& msg, bool wait_send) {
803
+ void WebsocketTlsTPM::poll_message (const std::shared_ptr<WebsocketMessage>& msg) {
804
+
779
805
if (std::this_thread::get_id () == conn_data->get_lws_thread_id ()) {
780
806
EVLOG_AND_THROW (std::runtime_error (" Deadlock detected, polling send from client lws thread!" ));
781
807
}
@@ -790,17 +816,17 @@ void WebsocketTlsTPM::poll_message(const std::shared_ptr<WebsocketMessage>& msg,
790
816
// Request a write callback
791
817
request_write ();
792
818
793
- if (wait_send) {
794
- std::unique_lock lock (this ->queue_mutex );
795
- msg_send_cv.wait_for (lock, std::chrono::seconds (10 ), [&] { return (true == msg->message_sent ); });
819
+ {
820
+ std::unique_lock lock (this ->msg_send_cv_mutex );
821
+ if (msg_send_cv.wait_for (lock, std::chrono::seconds (20 ), [&] { return (true == msg->message_sent ); })) {
822
+ EVLOG_info << " Successfully sent last message over TLS websocket!" ;
823
+ } else {
824
+ EVLOG_warning << " Could not send last message over TLS websocket!" ;
825
+ }
796
826
}
797
-
798
- if (msg->message_sent )
799
- EVLOG_info << " Successfully sent last message over TLS websocket!" ;
800
- else
801
- EVLOG_warning << " Could not send last message over TLS websocket!" ;
802
827
}
803
828
829
+ // Will be called from external threads
804
830
bool WebsocketTlsTPM::send (const std::string& message) {
805
831
if (!this ->initialized ()) {
806
832
EVLOG_error << " Could not send message because websocket is not properly initialized." ;
@@ -811,7 +837,7 @@ bool WebsocketTlsTPM::send(const std::string& message) {
811
837
msg->payload = std::move (message);
812
838
msg->protocol = LWS_WRITE_TEXT;
813
839
814
- poll_message (msg, true );
840
+ poll_message (msg);
815
841
816
842
return msg->message_sent ;
817
843
}
@@ -825,7 +851,7 @@ void WebsocketTlsTPM::ping() {
825
851
msg->payload = this ->connection_options .ping_payload ;
826
852
msg->protocol = LWS_WRITE_PING;
827
853
828
- poll_message (msg, true );
854
+ poll_message (msg);
829
855
}
830
856
831
857
int WebsocketTlsTPM::process_callback (void * wsi_ptr, int callback_reason, void * user, void * in, size_t len) {
@@ -949,31 +975,53 @@ int WebsocketTlsTPM::process_callback(void* wsi_ptr, int callback_reason, void*
949
975
950
976
case LWS_CALLBACK_CLIENT_WRITEABLE:
951
977
on_writable ();
952
-
953
- if (false == message_queue.empty ()) {
954
- lws_callback_on_writable (wsi);
978
+ {
979
+ bool message_queue_empty;
980
+ {
981
+ std::lock_guard<std::mutex> lock (this ->queue_mutex );
982
+ message_queue_empty = message_queue.empty ();
983
+ }
984
+ if (false == message_queue_empty) {
985
+ lws_callback_on_writable (wsi);
986
+ }
955
987
}
956
988
break ;
957
989
958
- case LWS_CALLBACK_CLIENT_RECEIVE_PONG:
959
- if (false == message_queue.empty ()) {
990
+ case LWS_CALLBACK_CLIENT_RECEIVE_PONG: {
991
+ bool message_queue_empty;
992
+ {
993
+ std::lock_guard<std::mutex> lock (this ->queue_mutex );
994
+ message_queue_empty = message_queue.empty ();
995
+ }
996
+ if (false == message_queue_empty) {
960
997
lws_callback_on_writable (data->get_conn ());
961
998
}
962
- break ;
999
+ } break ;
963
1000
964
1001
case LWS_CALLBACK_CLIENT_RECEIVE:
965
1002
on_message (in, len);
966
-
967
- if (false == message_queue.empty ()) {
968
- lws_callback_on_writable (data->get_conn ());
1003
+ {
1004
+ bool message_queue_empty;
1005
+ {
1006
+ std::lock_guard<std::mutex> lock (this ->queue_mutex );
1007
+ message_queue_empty = message_queue.empty ();
1008
+ }
1009
+ if (false == message_queue_empty) {
1010
+ lws_callback_on_writable (data->get_conn ());
1011
+ }
969
1012
}
970
1013
break ;
971
1014
972
- case LWS_CALLBACK_EVENT_WAIT_CANCELLED:
973
- if (false == message_queue.empty ()) {
1015
+ case LWS_CALLBACK_EVENT_WAIT_CANCELLED: {
1016
+ bool message_queue_empty;
1017
+ {
1018
+ std::lock_guard<std::mutex> lock (this ->queue_mutex );
1019
+ message_queue_empty = message_queue.empty ();
1020
+ }
1021
+ if (false == message_queue_empty) {
974
1022
lws_callback_on_writable (data->get_conn ());
975
1023
}
976
- break ;
1024
+ } break ;
977
1025
978
1026
default :
979
1027
EVLOG_info << " Callback with unhandled reason: " << reason;
0 commit comments