diff --git a/include/rtcdcpp/DataChannel.hpp b/include/rtcdcpp/DataChannel.hpp index 288aa4a..ae6abf8 100644 --- a/include/rtcdcpp/DataChannel.hpp +++ b/include/rtcdcpp/DataChannel.hpp @@ -45,6 +45,7 @@ namespace rtcdcpp { #define PPID_BINARY_EMPTY 57 // DataChannel Control Types +#define DC_TYPE_CLOSE 0x04 #define DC_TYPE_OPEN 0x03 #define DC_TYPE_ACK 0x02 diff --git a/include/rtcdcpp/PeerConnection.hpp b/include/rtcdcpp/PeerConnection.hpp index 87da9e6..1e3b6e7 100644 --- a/include/rtcdcpp/PeerConnection.hpp +++ b/include/rtcdcpp/PeerConnection.hpp @@ -56,6 +56,7 @@ struct RTCConfiguration { }; class PeerConnection { + friend class DataChannel; public: struct IceCandidate { IceCandidate(const std::string &candidate, const std::string &sdpMid, int sdpMLineIndex) @@ -150,10 +151,11 @@ class PeerConnection { // DataChannel message parsing void HandleNewDataChannel(ChunkPtr chunk, uint16_t sid); + void HandleDataChannelClose(uint16_t sid); void HandleStringMessage(ChunkPtr chunk, uint16_t sid); void HandleBinaryMessage(ChunkPtr chunk, uint16_t sid); std::shared_ptr logger = GetLogger("rtcdcpp.PeerConnection"); - + void ResetSCTPStream(uint16_t stream_id); }; } diff --git a/include/rtcdcpp/SCTPWrapper.hpp b/include/rtcdcpp/SCTPWrapper.hpp index 3600d63..629536f 100644 --- a/include/rtcdcpp/SCTPWrapper.hpp +++ b/include/rtcdcpp/SCTPWrapper.hpp @@ -54,6 +54,7 @@ class SCTPWrapper { bool Initialize(); void Start(); void Stop(); + void ResetSCTPStream(uint16_t stream_id, uint16_t srs_flags); // int GetStreamCursor(); // void SetStreamCursor(int i); diff --git a/src/DataChannel.cpp b/src/DataChannel.cpp index 4702d22..9049d47 100644 --- a/src/DataChannel.cpp +++ b/src/DataChannel.cpp @@ -46,7 +46,7 @@ DataChannel::DataChannel(PeerConnection *pc, uint16_t stream_id, uint8_t chan_ty error_cb = [](std::string x) { ; }; } -DataChannel::~DataChannel() { ; } +DataChannel::~DataChannel() { DataChannel::Close(); } uint16_t DataChannel::GetStreamID() { return this->stream_id; } @@ -57,9 +57,11 @@ std::string DataChannel::GetLabel() { return this->label; } std::string DataChannel::GetProtocol() { return this->protocol; } /** - * TODO: Close the DataChannel. + * Close the DataChannel. */ -void Close() { ; } +void DataChannel::Close() { + this->pc->ResetSCTPStream(GetStreamID()); +} bool DataChannel::SendString(std::string msg) { std::cerr << "DC: Sending string: " << msg << std::endl; diff --git a/src/PeerConnection.cpp b/src/PeerConnection.cpp index 46e4e71..f409187 100644 --- a/src/PeerConnection.cpp +++ b/src/PeerConnection.cpp @@ -186,11 +186,15 @@ void PeerConnection::OnSCTPMsgReceived(ChunkPtr chunk, uint16_t sid, uint32_t pp if (ppid == PPID_CONTROL) { SPDLOG_TRACE(logger, "Control PPID"); if (chunk->Data()[0] == DC_TYPE_OPEN) { + logger->info("DC TYPE OPEN RECEIVED on SID: {}", sid); SPDLOG_TRACE(logger, "New channel time!"); HandleNewDataChannel(chunk, sid); } else if (chunk->Data()[0] == DC_TYPE_ACK) { SPDLOG_TRACE(logger, "DC ACK"); // HandleDataChannelAck(chunk, sid); XXX: Don't care right now + } else if (chunk->Data()[0] == DC_TYPE_CLOSE) { + SPDLOG_TRACE(logger, "DC CLOSE"); + HandleDataChannelClose(sid); } else { SPDLOG_TRACE(logger, "Unknown msg_type for ppid control: {}", chunk->Data()[0]); } @@ -240,6 +244,15 @@ void PeerConnection::HandleNewDataChannel(ChunkPtr chunk, uint16_t sid) { } } +void PeerConnection::HandleDataChannelClose(uint16_t sid) { + auto cur_channel = GetChannel(sid); + if (!cur_channel) { + logger->warn("Received close for unknown channel: {}", sid); + return; + } + cur_channel->OnClosed(); +} + void PeerConnection::HandleStringMessage(ChunkPtr chunk, uint16_t sid) { auto cur_channel = GetChannel(sid); if (!cur_channel) { @@ -263,12 +276,53 @@ void PeerConnection::HandleBinaryMessage(ChunkPtr chunk, uint16_t sid) { } void PeerConnection::SendStrMsg(std::string str_msg, uint16_t sid) { - auto cur_msg = std::make_shared((const uint8_t *)str_msg.c_str(), str_msg.size()); - this->sctp->GSForSCTP(cur_msg, sid, PPID_STRING); + auto chan = GetChannel(sid); + if (chan) { + auto cur_msg = std::make_shared((const uint8_t *)str_msg.c_str(), str_msg.size()); + this->sctp->GSForSCTP(cur_msg, sid, PPID_STRING); + } else { + throw runtime_error("Datachannel does not exist"); + } } void PeerConnection::SendBinaryMsg(const uint8_t *data, int len, uint16_t sid) { - auto cur_msg = std::make_shared(data, len); - this->sctp->GSForSCTP(cur_msg, sid, PPID_BINARY); + auto chan = GetChannel(sid); + if (chan) { + auto cur_msg = std::make_shared(data, len); + this->sctp->GSForSCTP(cur_msg, sid, PPID_BINARY); + } else { + throw runtime_error("Datachannel does not exist"); + } } + +void PeerConnection::CreateDataChannel(std::string label, std::string protocol) { + uint16_t sid; + if (this->role == Client) { + sid = 0; + logger->info("Client SID"); + } else { + sid = 1; + logger->info("Server SID"); + } + for (int i = sid; i < data_channels.size(); i = i + 2) { + auto iter = data_channels.find(i); + if (iter == data_channels.end()) { + sid = i; + break; + } + } + + this->sctp->SetDataChannelSID(sid); + logger->info("Creating DC on SID: {}", sid); + auto new_channel = std::make_shared(this, sid, DATA_CHANNEL_RELIABLE, label, protocol); + data_channels[sid] = new_channel; + + std::thread create_dc = std::thread(&SCTPWrapper::CreateDCForSCTP, sctp.get(), label, protocol); + logger->info("Spawning create_dc thread"); + create_dc.detach(); +} +void PeerConnection::ResetSCTPStream(uint16_t stream_id) { + this->sctp->ResetSCTPStream(stream_id, SCTP_STREAM_RESET_OUTGOING); +} + } diff --git a/src/SCTPWrapper.cpp b/src/SCTPWrapper.cpp index 4b97641..9e1fa88 100644 --- a/src/SCTPWrapper.cpp +++ b/src/SCTPWrapper.cpp @@ -98,6 +98,47 @@ void SCTPWrapper::OnNotification(union sctp_notification *notify, size_t len) { break; case SCTP_STREAM_RESET_EVENT: SPDLOG_TRACE(logger, "OnNotification(type=SCTP_STREAM_RESET_EVENT)"); + struct sctp_stream_reset_event* reset_event; + reset_event = ¬ify->sn_strreset_event; + uint32_t e_length; + e_length = reset_event->strreset_length; + size_t list_len; + list_len = e_length - sizeof(*reset_event); + list_len /= sizeof(uint16_t); + for (int i = 0; i < list_len; i++) { + uint16_t streamid = reset_event->strreset_stream_list[i]; + uint16_t set_flags; + if (reset_event->strreset_flags != 0) { + if ((reset_event->strreset_flags ^ SCTP_STREAM_RESET_INCOMING_SSN) == 0) { + set_flags = SCTP_STREAM_RESET_OUTGOING; + } + if ((reset_event->strreset_flags ^ SCTP_STREAM_RESET_OUTGOING_SSN) == 0) { + //fires when we close the stream from our side explicity or + //as a result of remote close or some error. + + logger->info("Outgoing stream_id#{} have been reset, calling onClose CB", streamid); + const uint8_t dc_close_data = DC_TYPE_CLOSE; + const uint8_t *dc_close_ptr = &dc_close_data; + OnMsgReceived(dc_close_ptr, sizeof(dc_close_ptr), streamid, PPID_CONTROL); + //The above signals to call our onClose callback + } + if ((reset_event->strreset_flags ^ SCTP_STREAM_RESET_DENIED) == 0) { + logger->error("Stream reset denied by peer"); + } + if ((reset_event->strreset_flags ^ SCTP_STREAM_RESET_FAILED) == 0) { + logger->error("Stream reset failed"); + } + } else { + continue; + } + if (set_flags == SCTP_STREAM_RESET_OUTGOING) { + // Reset the stream when a remote close is received. + logger->info("SCTP Reset received for stream_id#{} from remote", streamid); + ResetSCTPStream(streamid, set_flags); + // This will cause another event SCTP_STREAM_RESET_OUTGOING_SSN + // where we can finally call our callbacks. + } + } break; case SCTP_ASSOC_RESET_EVENT: SPDLOG_TRACE(logger, "OnNotification(type=SCTP_ASSOC_RESET_EVENT)"); @@ -291,10 +332,94 @@ void SCTPWrapper::Stop() { usrsctp_close(sock); sock = nullptr; } + usrsctp_deregister_address(this); +} + +void SCTPWrapper::ResetSCTPStream(uint16_t stream_id, uint16_t srs_flags) { + struct sctp_reset_streams* stream_close = NULL; + size_t no_of_streams = 1; + size_t len = sizeof(stream_close) + sizeof(uint16_t); + stream_close = (sctp_reset_streams *) malloc(len); + memset(stream_close, 0, len); + stream_close->srs_flags = srs_flags; + stream_close->srs_number_streams = 1; + stream_close->srs_stream_list[0] = stream_id; + if (usrsctp_setsockopt(this->sock, IPPROTO_SCTP, SCTP_RESET_STREAMS, stream_close, (socklen_t) len) == -1) { + logger->error("Could not set socket options for SCTP_RESET_STREAMS. errno={}", errno); + } else { + logger->info("SCTP_RESET_STREAMS socket option has been set successfully for SID {}", stream_id); + } + free(stream_close); + stream_close = NULL; } void SCTPWrapper::DTLSForSCTP(ChunkPtr chunk) { this->recv_queue.push(chunk); } +uint16_t SCTPWrapper::GetSid(){ + return this->sid; + } + +dc_open_msg* SCTPWrapper::GetDataChannelData(){ + return this->data; + } + +std::string SCTPWrapper::GetLabel(){ + return this->label; + } +std::string SCTPWrapper::GetProtocol(){ + return this->label; + } +void SCTPWrapper::SetDataChannelSID(uint16_t sid) + { + this->sid = sid; + } +void SCTPWrapper::SendACK() { + struct sctp_sndinfo sinfo = {0}; // + sinfo.snd_sid = GetSid(); + sinfo.snd_ppid = htonl(PPID_CONTROL); + uint8_t payload = DC_TYPE_ACK; + if (usrsctp_sendv(this->sock, &payload, sizeof(uint8_t), NULL, 0, &sinfo, sizeof(sinfo), SCTP_SENDV_SNDINFO, 0) < 0) { + logger->error("Sending ACK failed"); + throw std::runtime_error("Sending ACK failed"); + } else { + logger->info("Ack has gone through, SID: {}", sid); + } +} +void SCTPWrapper::CreateDCForSCTP(std::string label, std::string protocol) { + + std::unique_lock l2(createDCMtx); + while (!this->readyDataChannel) { + createDC.wait(l2); + } + struct sctp_sndinfo sinfo = {0}; + int sid; + sid = this->sid; + sinfo.snd_sid = sid; + sinfo.snd_ppid = htonl(PPID_CONTROL); + + int total_size = sizeof *this->data + label.size() + protocol.size() - (2 * sizeof(char *)); + this->data = (dc_open_msg *)calloc(1, total_size); + this->data->msg_type = DC_TYPE_OPEN; + this->data->chan_type = DATA_CHANNEL_RELIABLE; + this->data->priority = htons(0); // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-10#section-6.4 + this->data->reliability = htonl(0); + this->data->label_len = htons(label.length()); + this->data->protocol_len = htons(protocol.length()); + // try to overwrite last two char* from the struct + memcpy(&this->data->label, label.c_str(), label.length()); + memcpy(&this->data->label + label.length(), protocol.c_str(), protocol.length()); + + this->label = label.c_str(); + this->protocol = protocol.c_str(); + + if (started) { + if (usrsctp_sendv(this->sock, this->data, total_size, NULL, 0, &sinfo, sizeof(sinfo), SCTP_SENDV_SNDINFO, 0) < 0) { + logger->error("Failed to send a datachannel open request."); + } else { + logger->info("Datachannel open request has gone through."); + } + } +} // Send a message to the remote connection void SCTPWrapper::GSForSCTP(ChunkPtr chunk, uint16_t sid, uint32_t ppid) { struct sctp_sendv_spa spa = {0};