Skip to content

Commit

Permalink
Enable access control in SRT listener mode.
Browse files Browse the repository at this point in the history
  • Loading branch information
shogo4405 committed Feb 24, 2025
1 parent 84837a2 commit 6f52a2c
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Examples/Preference.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ struct Preference: Sendable {
// Temp
static nonisolated(unsafe) var `default` = Preference()

var uri: String? = "rtmp://192.168.1.6/live"
var uri: String? = "srt://:9002?streamid=hello&passphrase=passphrasepassphrasepassphrasepassphrasepassphrasepassphrase"
var streamName: String? = "live"
}
62 changes: 41 additions & 21 deletions SRTHaishinKit/Sources/SRT/SRTConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ public actor SRTConnection: NetworkConnection {

private var socket: SRTSocket? {
didSet {
guard let socket else {
return
}
Task {
guard let socket else {
return
}
let networkMonitor = await socket.makeNetworkMonitor()
self.networkMonitor = networkMonitor
await networkMonitor.startRunning()
Expand All @@ -38,10 +38,19 @@ public actor SRTConnection: NetworkConnection {
}
}
}
Task {
await oldValue?.stopRunning()
}
}
}
private var streams: [SRTStream] = []
private var listener: SRTSocket?
private var listener: SRTSocket? {
didSet {
Task {
await oldValue?.stopRunning()
}
}
}
private var networkMonitor: NetworkMonitor?

/// The SRT's performance data.
Expand Down Expand Up @@ -70,14 +79,14 @@ public actor SRTConnection: NetworkConnection {
/// Creates a connection to the server or waits for an incoming connection.
///
/// - Parameters:
/// - url: You can specify connection options in the URL. This follows the standard SRT format.
/// - uri: You can specify connection options in the URL. This follows the standard SRT format.
///
/// - srt://192.168.1.1:9000?mode=caller
/// - Connect to the specified server.
/// - srt://:9000?mode=listener
/// - Wait for connections as a server.
public func connect(_ url: URL?) async throws {
guard let uri = uri, let scheme = uri.scheme, let host = uri.host, let port = uri.port, scheme == "srt" else {
public func connect(_ uri: URL?) async throws {
guard let uri, let scheme = uri.scheme, let host = uri.host, let port = uri.port, scheme == "srt" else {
throw Error.unsupportedUri(uri)
}
guard let mode = SRTSocketOption.getMode(uri: uri) else {
Expand Down Expand Up @@ -105,25 +114,26 @@ public actor SRTConnection: NetworkConnection {
}
}

/// Closes the connection from the server.
public func close() async throws {
guard connected else {
throw Error.invalidState
/// Closes a connection.
public func close() async {
guard uri != nil else {
return
}
await networkMonitor?.stopRunning()
networkMonitor = nil
for stream in streams {
await stream.close()
}
await socket?.stopRunning()
await listener?.stopRunning()
socket = nil
listener = nil
uri = nil
connected = false
}

func send(_ data: Data) async {
do {
try await socket?.send(data)
} catch {
try? await close()
await close()
}
}

Expand Down Expand Up @@ -151,19 +161,29 @@ public actor SRTConnection: NetworkConnection {
}
}

func acceptSocket() async {
guard let listener else {
return
}
do {
socket = try await listener.accept()
// It is a one-by-one connection and stops once the first connection is established.
self.listener = nil
connected = true
} catch {
logger.error(error)
await acceptSocket()
}
}

private func setMode(_ mode: SRTMode, socket: SRTSocket) {
switch mode {
case .caller:
self.socket = socket
case .listener:
listener = socket
Task {
for await accept in await socket.accept {
self.socket = accept
await listener?.stopRunning()
listener = nil
connected = true
}
await acceptSocket()
}
}
}
Expand Down
52 changes: 29 additions & 23 deletions SRTHaishinKit/Sources/SRT/SRTSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ final actor SRTSocket {
static let payloadSize: Int = 1316

enum Error: Swift.Error {
case rejected
case notConnected
case illegalState(message: String)
}
Expand Down Expand Up @@ -72,22 +73,6 @@ final actor SRTSocket {
}
}

var accept: AsyncStream<SRTSocket> {
AsyncStream<SRTSocket> { continuation in
Task.detached {
repeat {
do {
let client = try await self.accept()
continuation.yield(client)
try await Task.sleep(nanoseconds: 1_000_000_000)
} catch {
continuation.finish()
}
} while await self.connected
}
}
}

var performanceData: SRTPerformanceData {
.init(mon: perf)
}
Expand Down Expand Up @@ -122,10 +107,9 @@ final actor SRTSocket {
if incomingBuffer.count < windowSizeC {
incomingBuffer = .init(count: Int(windowSizeC))
}
await startRunning()
}

func open(_ addr: sockaddr_in, mode: SRTMode, options: [SRTSocketOption: any Sendable] = [:]) throws {
func open(_ addr: sockaddr_in, mode: SRTMode, options: [SRTSocketOption: any Sendable] = [:]) async throws {
guard socket == SRT_INVALID_SOCK else {
return
}
Expand Down Expand Up @@ -163,6 +147,25 @@ final actor SRTSocket {
throw makeSocketError()
}
}
await startRunning()
}

func accept() async throws -> SRTSocket {
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<SRTSocket, Swift.Error>) in
Task.detached { [self] in
do {
let accept = srt_accept(await socket, nil, nil)
guard -1 < accept else {
throw await makeSocketError()
}
let socket = try await SRTSocket(socket: accept)
socket.stopRunning()
continuation.resume(returning: socket)
} catch {
continuation.resume(throwing: error)
}
}
}
}

func send(_ data: Data) throws {
Expand All @@ -174,6 +177,14 @@ final actor SRTSocket {
}
}

func getOption(_ option: SRTSocketOption) throws -> String? {
return String(data: try option.getOption(socket), encoding: .ascii)
}

private func getOption(_ option: SRTSocketOption) throws -> Data {
return try option.getOption(socket)
}

private func configure(_ binding: SRTSocketOption.Binding) -> Bool {
let failures = SRTSocketOption.configure(socket, binding: binding, options: options)
guard failures.isEmpty else {
Expand All @@ -198,11 +209,6 @@ final actor SRTSocket {
return .illegalState(message: error_message)
}

private func accept() async throws -> SRTSocket {
let accept = srt_accept(socket, nil, nil)
return try await SRTSocket(socket: accept)
}

@inline(__always)
private func sendmsg(_ data: Data) -> Int32 {
return data.withUnsafeBytes { pointer in
Expand Down
2 changes: 1 addition & 1 deletion SRTHaishinKit/Sources/SRT/SRTSocketOption.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ enum SRTSocketOption: String, Sendable {
if uri.host?.isEmpty == true {
return .listener
}
return nil
return .caller
}
}

Expand Down
31 changes: 31 additions & 0 deletions SRTHaishinKit/Tests/SRT/SRTConnectionTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import Foundation
import Testing

import libsrt
@testable import SRTHaishinKit

@Suite struct SRTConnectionTests {
@Test func streamid_success() async throws {
let listener = SRTConnection()
try await listener.connect(URL(string: "srt://:10000?streamid=test"))
let connection = SRTConnection()
try await connection.connect(URL(string: "srt://127.0.0.1:10000?streamid=test"))
await connection.close()
await listener.close()
}

@Test func streamid_failed_success() async throws {
let listener = SRTConnection()
try await listener.connect(URL(string: "srt://:10001?streamid=test&passphrase=a546994dbf25a0823f0cbadff9cc5088k9e7c2027e8e40933a04ef574bc61cd4a"))
let connection1 = SRTConnection()
await #expect(throws: SRTError.self) {
try await connection1.connect(URL(string: "srt://127.0.0.1:10001?streamid=test2&passphrase=a546994dbf25a0823f0cbadff9cc5088k9e7c2027e8e40933a04ef574bc61cd4"))
}
let connection2 = SRTConnection()
try await connection2.connect(URL(string: "srt://127.0.0.1:10001?streamid=test&passphrase=a546994dbf25a0823f0cbadff9cc5088k9e7c2027e8e40933a04ef574bc61cd4a"))
await #expect(connection2.connected == true)
await connection1.close()
await connection2.close()
await listener.close()
}
}
1 change: 1 addition & 0 deletions SRTHaishinKit/Tests/SRT/SRTSocketOptionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import libsrt
#expect(SRTSocketOption.getMode(uri: URL(string: "srt://192.168.1.1:9000?mode=client")) == SRTMode.caller)
#expect(SRTSocketOption.getMode(uri: URL(string: "srt://192.168.1.1:9000?mode=listener")) == SRTMode.listener)
#expect(SRTSocketOption.getMode(uri: URL(string: "srt://192.168.1.1:9000?mode=server")) == SRTMode.listener)
#expect(SRTSocketOption.getMode(uri: URL(string: "srt://192.168.1.1:9000")) == SRTMode.caller)
#expect(SRTSocketOption.getMode(uri: URL(string: "srt://:9000")) == SRTMode.listener)
}
}

0 comments on commit 6f52a2c

Please sign in to comment.