Skip to content

Commit 3e15729

Browse files
Libwebsockets: Fix message loss on reconnect
Signed-off-by: Cornelius Claussen <cc@pionix.de>
1 parent 06fcfa3 commit 3e15729

File tree

3 files changed

+100
-50
lines changed

3 files changed

+100
-50
lines changed

include/ocpp/common/websocket/websocket_base.hpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ enum class ConnectionFailedReason {
4747
///
4848
class WebsocketBase {
4949
protected:
50-
bool m_is_connected;
50+
std::atomic_bool m_is_connected;
5151
WebsocketConnectionOptions connection_options;
5252
std::function<void(const int security_profile)> connected_callback;
5353
std::function<void()> disconnected_callback;
@@ -59,11 +59,11 @@ class WebsocketBase {
5959
websocketpp::connection_hdl handle;
6060
std::mutex reconnect_mutex;
6161
std::mutex connection_mutex;
62-
long reconnect_backoff_ms;
62+
std::atomic_int reconnect_backoff_ms;
6363
websocketpp::transport::timer_handler reconnect_callback;
64-
int connection_attempts;
65-
bool shutting_down;
66-
bool reconnecting;
64+
std::atomic_int connection_attempts;
65+
std::atomic_bool shutting_down;
66+
std::atomic_bool reconnecting;
6767

6868
/// \brief Indicates if the required callbacks are registered
6969
/// \returns true if the websocket is properly initialized

include/ocpp/common/websocket/websocket_tls_tpm.hpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class WebsocketTlsTPM final : public WebsocketBase {
6767

6868
void request_write();
6969

70-
void poll_message(const std::shared_ptr<WebsocketMessage>& msg, bool wait_sendaf);
70+
void poll_message(const std::shared_ptr<WebsocketMessage>& msg);
7171

7272
private:
7373
std::shared_ptr<EvseSecurity> evse_security;
@@ -79,8 +79,10 @@ class WebsocketTlsTPM final : public WebsocketBase {
7979
std::condition_variable conn_cv;
8080

8181
std::mutex queue_mutex;
82+
8283
std::queue<std::shared_ptr<WebsocketMessage>> message_queue;
8384
std::condition_variable msg_send_cv;
85+
std::mutex msg_send_cv_mutex;
8486

8587
std::unique_ptr<std::thread> recv_message_thread;
8688
std::mutex recv_mutex;

lib/ocpp/common/websocket/websocket_tls_tpm.cpp

+92-44
Original file line numberDiff line numberDiff line change
@@ -87,41 +87,47 @@ struct ConnectionData {
8787
}
8888

8989
bool is_connecting() {
90-
return (state == EConnectionState::CONNECTING);
90+
return (state.load() == EConnectionState::CONNECTING);
9191
}
9292

9393
bool is_close_requested() {
9494
return is_marked_close;
9595
}
9696

9797
auto get_state() {
98-
return state;
98+
return state.load();
9999
}
100100

101101
lws* get_conn() {
102102
return wsi;
103103
}
104104

105-
lws_context* get_ctx() {
106-
return lws_ctx.get();
105+
WebsocketTlsTPM* get_owner() {
106+
return owner.load();
107+
}
108+
109+
void set_owner(WebsocketTlsTPM* o) {
110+
owner = o;
107111
}
108112

109113
public:
114+
// This public block will only be used from client loop thread, no locking needed
110115
// Openssl context, must be destroyed in this order
111116
std::unique_ptr<SSL_CTX> sec_context;
112117
std::unique_ptr<OSSL_LIB_CTX> sec_lib_context;
113-
114118
// libwebsockets state
115119
std::unique_ptr<lws_context> lws_ctx;
116-
lws* wsi;
117120

118-
WebsocketTlsTPM* owner;
121+
lws* wsi;
119122

120123
private:
124+
std::atomic<WebsocketTlsTPM*> owner;
125+
121126
std::thread::id lws_thread_id;
122-
bool is_running;
123-
bool is_marked_close;
124-
EConnectionState state;
127+
128+
std::atomic_bool is_running;
129+
std::atomic_bool is_marked_close;
130+
std::atomic<EConnectionState> state;
125131
};
126132

127133
struct WebsocketMessage {
@@ -140,7 +146,7 @@ struct WebsocketMessage {
140146
// just that these were sent to libwebsockets
141147
size_t sent_bytes;
142148
// If libwebsockets has sent all the bytes through the wire
143-
volatile bool message_sent;
149+
std::atomic_bool message_sent;
144150
};
145151

146152
WebsocketTlsTPM::WebsocketTlsTPM(const WebsocketConnectionOptions& connection_options,
@@ -189,7 +195,12 @@ static int callback_minimal(struct lws* wsi, enum lws_callback_reasons reason, v
189195
// Get user safely, since on some callbacks (void *user) can be different than what we set
190196
if (wsi != nullptr) {
191197
if (ConnectionData* data = reinterpret_cast<ConnectionData*>(lws_wsi_user(wsi))) {
192-
return data->owner->process_callback(wsi, static_cast<int>(reason), user, in, len);
198+
auto owner = data->get_owner();
199+
if (owner not_eq nullptr) {
200+
return data->get_owner()->process_callback(wsi, static_cast<int>(reason), user, in, len);
201+
} else {
202+
EVLOG_error << "callback_minimal called, but data->owner is nullptr";
203+
}
193204
}
194205
}
195206

@@ -331,11 +342,14 @@ void WebsocketTlsTPM::recv_loop() {
331342

332343
while (false == data->is_interupted()) {
333344
// Process all messages
334-
while (false == recv_message_queue.empty()) {
345+
while (true) {
335346
std::string message{};
336347

337348
{
338349
std::lock_guard lk(this->recv_mutex);
350+
if (recv_message_queue.empty())
351+
break;
352+
339353
message = std::move(recv_message_queue.front());
340354
recv_message_queue.pop();
341355
}
@@ -459,9 +473,14 @@ void WebsocketTlsTPM::client_loop() {
459473

460474
while (n >= 0 && (false == data->is_interupted())) {
461475
// Set to -1 for continuous servicing, of required, not recommended
462-
n = lws_service(data->get_ctx(), 0);
476+
n = lws_service(data->lws_ctx.get(), 0);
463477

464-
if (false == message_queue.empty()) {
478+
bool message_queue_empty;
479+
{
480+
std::lock_guard<std::mutex> lock(this->queue_mutex);
481+
message_queue_empty = message_queue.empty();
482+
}
483+
if (false == message_queue_empty) {
465484
lws_callback_on_writable(data->get_conn());
466485
}
467486
}
@@ -470,6 +489,7 @@ void WebsocketTlsTPM::client_loop() {
470489
EVLOG_debug << "Exit client loop with ID: " << std::this_thread::get_id();
471490
}
472491

492+
// Will be called from external threads as well
473493
bool WebsocketTlsTPM::connect() {
474494
if (!this->initialized()) {
475495
return false;
@@ -485,7 +505,7 @@ bool WebsocketTlsTPM::connect() {
485505
}
486506

487507
auto conn_data = new ConnectionData();
488-
conn_data->owner = this;
508+
conn_data->set_owner(this);
489509

490510
this->conn_data.reset(conn_data);
491511

@@ -600,7 +620,8 @@ void WebsocketTlsTPM::on_conn_connected() {
600620
this->m_is_connected = true;
601621
this->reconnecting = false;
602622

603-
this->connected_callback(this->connection_options.security_profile);
623+
std::thread connected([this]() { this->connected_callback(this->connection_options.security_profile); });
624+
connected.detach();
604625
}
605626

606627
void WebsocketTlsTPM::on_conn_close() {
@@ -717,11 +738,14 @@ void WebsocketTlsTPM::on_writable() {
717738
return;
718739
}
719740

720-
while (false == message_queue.empty()) {
741+
while (true) {
721742
WebsocketMessage* message = nullptr;
722743

723744
{
724745
std::lock_guard<std::mutex> lock(this->queue_mutex);
746+
if (message_queue.empty()) {
747+
break;
748+
}
725749
message = message_queue.front().get();
726750
}
727751

@@ -744,7 +768,7 @@ void WebsocketTlsTPM::on_writable() {
744768

745769
EVLOG_debug << "Notifying waiting thread!";
746770
// Notify any waiting thread to check it's state
747-
msg_send_cv.notify_one();
771+
msg_send_cv.notify_all();
748772
} else {
749773
EVLOG_debug << "Client writable, sending message part!";
750774

@@ -766,16 +790,18 @@ void WebsocketTlsTPM::request_write() {
766790
if (this->m_is_connected) {
767791
if (auto* data = conn_data.get()) {
768792
if (data->get_conn()) {
769-
// Notify waiting processing thread to wake up
770-
lws_cancel_service(data->get_ctx());
793+
// Notify waiting processing thread to wake up. According to docs it is ok to call from another
794+
// thread.
795+
lws_cancel_service(data->lws_ctx.get());
771796
}
772797
}
773798
} else {
774799
EVLOG_warning << "Requested write with offline TLS websocket!";
775800
}
776801
}
777802

778-
void WebsocketTlsTPM::poll_message(const std::shared_ptr<WebsocketMessage>& msg, bool wait_send) {
803+
void WebsocketTlsTPM::poll_message(const std::shared_ptr<WebsocketMessage>& msg) {
804+
779805
if (std::this_thread::get_id() == conn_data->get_lws_thread_id()) {
780806
EVLOG_AND_THROW(std::runtime_error("Deadlock detected, polling send from client lws thread!"));
781807
}
@@ -790,17 +816,17 @@ void WebsocketTlsTPM::poll_message(const std::shared_ptr<WebsocketMessage>& msg,
790816
// Request a write callback
791817
request_write();
792818

793-
if (wait_send) {
794-
std::unique_lock lock(this->queue_mutex);
795-
msg_send_cv.wait_for(lock, std::chrono::seconds(10), [&] { return (true == msg->message_sent); });
819+
{
820+
std::unique_lock lock(this->msg_send_cv_mutex);
821+
if (msg_send_cv.wait_for(lock, std::chrono::seconds(20), [&] { return (true == msg->message_sent); })) {
822+
EVLOG_info << "Successfully sent last message over TLS websocket!";
823+
} else {
824+
EVLOG_warning << "Could not send last message over TLS websocket!";
825+
}
796826
}
797-
798-
if (msg->message_sent)
799-
EVLOG_info << "Successfully sent last message over TLS websocket!";
800-
else
801-
EVLOG_warning << "Could not send last message over TLS websocket!";
802827
}
803828

829+
// Will be called from external threads
804830
bool WebsocketTlsTPM::send(const std::string& message) {
805831
if (!this->initialized()) {
806832
EVLOG_error << "Could not send message because websocket is not properly initialized.";
@@ -811,7 +837,7 @@ bool WebsocketTlsTPM::send(const std::string& message) {
811837
msg->payload = std::move(message);
812838
msg->protocol = LWS_WRITE_TEXT;
813839

814-
poll_message(msg, true);
840+
poll_message(msg);
815841

816842
return msg->message_sent;
817843
}
@@ -825,7 +851,7 @@ void WebsocketTlsTPM::ping() {
825851
msg->payload = this->connection_options.ping_payload;
826852
msg->protocol = LWS_WRITE_PING;
827853

828-
poll_message(msg, true);
854+
poll_message(msg);
829855
}
830856

831857
int WebsocketTlsTPM::process_callback(void* wsi_ptr, int callback_reason, void* user, void* in, size_t len) {
@@ -949,31 +975,53 @@ int WebsocketTlsTPM::process_callback(void* wsi_ptr, int callback_reason, void*
949975

950976
case LWS_CALLBACK_CLIENT_WRITEABLE:
951977
on_writable();
952-
953-
if (false == message_queue.empty()) {
954-
lws_callback_on_writable(wsi);
978+
{
979+
bool message_queue_empty;
980+
{
981+
std::lock_guard<std::mutex> lock(this->queue_mutex);
982+
message_queue_empty = message_queue.empty();
983+
}
984+
if (false == message_queue_empty) {
985+
lws_callback_on_writable(wsi);
986+
}
955987
}
956988
break;
957989

958-
case LWS_CALLBACK_CLIENT_RECEIVE_PONG:
959-
if (false == message_queue.empty()) {
990+
case LWS_CALLBACK_CLIENT_RECEIVE_PONG: {
991+
bool message_queue_empty;
992+
{
993+
std::lock_guard<std::mutex> lock(this->queue_mutex);
994+
message_queue_empty = message_queue.empty();
995+
}
996+
if (false == message_queue_empty) {
960997
lws_callback_on_writable(data->get_conn());
961998
}
962-
break;
999+
} break;
9631000

9641001
case LWS_CALLBACK_CLIENT_RECEIVE:
9651002
on_message(in, len);
966-
967-
if (false == message_queue.empty()) {
968-
lws_callback_on_writable(data->get_conn());
1003+
{
1004+
bool message_queue_empty;
1005+
{
1006+
std::lock_guard<std::mutex> lock(this->queue_mutex);
1007+
message_queue_empty = message_queue.empty();
1008+
}
1009+
if (false == message_queue_empty) {
1010+
lws_callback_on_writable(data->get_conn());
1011+
}
9691012
}
9701013
break;
9711014

972-
case LWS_CALLBACK_EVENT_WAIT_CANCELLED:
973-
if (false == message_queue.empty()) {
1015+
case LWS_CALLBACK_EVENT_WAIT_CANCELLED: {
1016+
bool message_queue_empty;
1017+
{
1018+
std::lock_guard<std::mutex> lock(this->queue_mutex);
1019+
message_queue_empty = message_queue.empty();
1020+
}
1021+
if (false == message_queue_empty) {
9741022
lws_callback_on_writable(data->get_conn());
9751023
}
976-
break;
1024+
} break;
9771025

9781026
default:
9791027
EVLOG_info << "Callback with unhandled reason: " << reason;

0 commit comments

Comments
 (0)