diff --git a/SRTHaishinKit/Sources/SRT/SRTConnection.swift b/SRTHaishinKit/Sources/SRT/SRTConnection.swift index 6b4eb639c..48e455dfa 100644 --- a/SRTHaishinKit/Sources/SRT/SRTConnection.swift +++ b/SRTHaishinKit/Sources/SRT/SRTConnection.swift @@ -60,7 +60,7 @@ public actor SRTConnection: NetworkConnection { do { try await socket.open(addr, mode: mode, options: options) self.uri = uri - connected = await socket.status == SRTS_CONNECTED + connected = await socket.status == .connected continuation.resume() } catch { continuation.resume(throwing: error) @@ -100,13 +100,13 @@ public actor SRTConnection: NetworkConnection { } await networkMonitor?.stopRunning() for client in clients { - await client.close() + await client.stopRunning() } clients.removeAll() for stream in streams { await stream.close() } - await socket?.close() + await socket?.stopRunning() connected = false } diff --git a/SRTHaishinKit/Sources/SRT/SRTSocket.swift b/SRTHaishinKit/Sources/SRT/SRTSocket.swift index f9ece37a4..4585b587b 100644 --- a/SRTHaishinKit/Sources/SRT/SRTSocket.swift +++ b/SRTHaishinKit/Sources/SRT/SRTSocket.swift @@ -11,6 +11,51 @@ final actor SRTSocket { case illegalState(message: String) } + enum Status: Int, CustomDebugStringConvertible { + case unknown + case `init` + case opened + case listening + case connecting + case connected + case broken + case closing + case closed + case nonexist + + var debugDescription: String { + switch self { + case .unknown: + return "unknown" + case .`init`: + return "init" + case .opened: + return "opened" + case .listening: + return "listening" + case .connecting: + return "connecting" + case .connected: + return "connected" + case .broken: + return "broken" + case .closing: + return "closing" + case .closed: + return "closed" + case .nonexist: + return "nonexist" + } + } + + init?(_ status: SRT_SOCKSTATUS) { + self.init(rawValue: Int(status.rawValue)) + defer { + logger.trace(debugDescription) + } + } + } + var inputs: AsyncStream { AsyncStream { continuation in // If Task.detached is not used, closing will result in a deadlock. @@ -44,50 +89,25 @@ final actor SRTSocket { } var performanceData: SRTPerformanceData { - return .init(mon: perf) + .init(mon: perf) } - private(set) var mode: SRTMode = .caller - private(set) var perf: CBytePerfMon = .init() - private(set) var socket: SRTSOCKET = SRT_INVALID_SOCK - private(set) var status: SRT_SOCKSTATUS = SRTS_INIT { - didSet { - guard status != oldValue else { - return - } - switch status { - case SRTS_INIT: // 1 - logger.trace("SRT Socket Init") - case SRTS_OPENED: - logger.info("SRT Socket opened") - case SRTS_LISTENING: - logger.trace("SRT Socket Listening") - case SRTS_CONNECTING: - logger.trace("SRT Socket Connecting") - case SRTS_CONNECTED: - logger.info("SRT Socket Connected") - didConnected() - case SRTS_BROKEN: - logger.warn("SRT Socket Broken") - close() - case SRTS_CLOSING: - logger.trace("SRT Socket Closing") - case SRTS_CLOSED: - logger.info("SRT Socket Closed") - case SRTS_NONEXIST: - logger.warn("SRT Socket Not Exist") - default: - break - } - } + + var status: Status { + .init(srt_getsockstate(socket)) ?? .unknown } - private(set) var options: [SRTSocketOption: any Sendable] = [:] + private(set) var isRunning = false + private var perf: CBytePerfMon = .init() + private var socket: SRTSOCKET = SRT_INVALID_SOCK + private var options: [SRTSocketOption: any Sendable] = [:] private var outputs: AsyncStream.Continuation? { didSet { oldValue?.finish() } } - private var connected = false + private var connected: Bool { + status == .connected + } private var windowSizeC: Int32 = 1024 * 4 private lazy var incomingBuffer: Data = .init(count: Int(windowSizeC)) @@ -102,20 +122,13 @@ final actor SRTSocket { if incomingBuffer.count < windowSizeC { incomingBuffer = .init(count: Int(windowSizeC)) } - status = srt_getsockstate(socket) - switch status { - case SRTS_CONNECTED: - didConnected() - default: - break - } + await startRunning() } func open(_ addr: sockaddr_in, mode: SRTMode, options: [SRTSocketOption: any Sendable] = [:]) throws { guard socket == SRT_INVALID_SOCK else { return } - self.mode = mode // prepare socket socket = srt_create_socket() if socket == SRT_INVALID_SOCK { @@ -150,18 +163,6 @@ final actor SRTSocket { throw makeSocketError() } } - status = srt_getsockstate(socket) - } - - func close() { - guard socket != SRT_INVALID_SOCK else { - return - } - srt_close(socket) - status = srt_getsockstate(socket) - socket = SRT_INVALID_SOCK - outputs = nil - connected = false } func send(_ data: Data) throws { @@ -173,7 +174,7 @@ final actor SRTSocket { } } - func configure(_ binding: SRTSocketOption.Binding) -> Bool { + private func configure(_ binding: SRTSocketOption.Binding) -> Bool { let failures = SRTSocketOption.configure(socket, binding: binding, options: options) guard failures.isEmpty else { logger.error(failures) @@ -189,24 +190,11 @@ final actor SRTSocket { return srt_bstats(socket, &perf, 1) } - private func didConnected() { - connected = true - let stream = AsyncStream { continuation in - self.outputs = continuation - } - Task { - for await data in stream where connected { - let result = sendmsg(data) - if result == -1 { - close() - } - } - } - } - private func makeSocketError() -> SRTError { let error_message = String(cString: srt_getlasterror_str()) - logger.error(error_message) + defer { + logger.error(error_message) + } return .illegalState(message: error_message) } @@ -236,6 +224,37 @@ final actor SRTSocket { } } +extension SRTSocket: AsyncRunner { + // MARK: AsyncRunner + func startRunning() async { + guard !isRunning else { + return + } + let stream = AsyncStream { continuation in + self.outputs = continuation + } + Task { + for await data in stream { + let result = sendmsg(data) + if result == -1 { + await stopRunning() + } + } + } + isRunning = true + } + + func stopRunning() async { + guard isRunning else { + return + } + srt_close(socket) + socket = SRT_INVALID_SOCK + outputs = nil + isRunning = false + } +} + extension SRTSocket: NetworkTransportReporter { // MARK: NetworkTransportReporter func makeNetworkTransportReport() -> NetworkTransportReport { @@ -243,8 +262,8 @@ extension SRTSocket: NetworkTransportReporter { let performanceData = self.performanceData return .init( queueBytesOut: Int(performanceData.byteSndBuf), - totalBytesIn: Int(performanceData.byteSentTotal), - totalBytesOut: Int(performanceData.byteRecvTotal) + totalBytesIn: Int(performanceData.byteRecvTotal), + totalBytesOut: Int(performanceData.byteSentTotal) ) }