-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathconn.go
128 lines (113 loc) · 3.18 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package openairt
import (
"context"
"errors"
"fmt"
)
// ServerEventHandler is a callback that receives each ServerEvent read from the
// server. A ConnHandler calls its registered handlers one after another, in
// registration order, on its read goroutine, passing the ConnHandler's context.
// Because calls are sequential, a slow handler delays reading the next message.
type ServerEventHandler func(ctx context.Context, event ServerEvent)
// Conn is a connection to the OpenAI Realtime API.
// It wraps a WebSocketConn and exchanges events with the server as text
// WebSocket messages (writes use MessageText; reads reject any other type).
type Conn struct {
	// logger receives diagnostics; ConnHandler's read loop uses it to report
	// non-permanent read errors.
	logger Logger
	// conn is the underlying WebSocket transport to which reads, writes,
	// pings and Close are delegated.
	conn WebSocketConn
}
// Close closes the underlying WebSocket connection and returns whatever error
// the transport reports, if any.
func (c *Conn) Close() (err error) {
	err = c.conn.Close()
	return err
}
// SendMessageRaw writes data to the server unchanged, as a single text
// WebSocket message. The caller is responsible for data being a valid
// client event payload; no validation happens here.
func (c *Conn) SendMessageRaw(ctx context.Context, data []byte) (err error) {
	err = c.conn.WriteMessage(ctx, MessageText, data)
	return err
}
// SendMessage encodes msg with MarshalClientEvent and sends the encoded bytes
// through SendMessageRaw. If encoding fails, its error is returned unchanged
// and nothing is written to the connection.
func (c *Conn) SendMessage(ctx context.Context, msg ClientEvent) error {
	payload, err := MarshalClientEvent(msg)
	if err == nil {
		err = c.SendMessageRaw(ctx, payload)
	}
	return err
}
// ReadMessageRaw blocks until the next message arrives from the server and
// returns its payload bytes. Transport errors are returned unchanged; a
// message of any type other than text is rejected with an error.
func (c *Conn) ReadMessageRaw(ctx context.Context) ([]byte, error) {
	typ, payload, err := c.conn.ReadMessage(ctx)
	switch {
	case err != nil:
		return nil, err
	case typ != MessageText:
		return nil, fmt.Errorf("expected text message, got %d", typ)
	default:
		return payload, nil
	}
}
// ReadMessage reads the next text message via ReadMessageRaw and decodes it
// with UnmarshalServerEvent. On any read or decode failure it returns a nil
// event together with that error.
func (c *Conn) ReadMessage(ctx context.Context) (ServerEvent, error) {
	var ev ServerEvent
	raw, err := c.ReadMessageRaw(ctx)
	if err == nil {
		ev, err = UnmarshalServerEvent(raw)
	}
	if err != nil {
		return nil, err
	}
	return ev, nil
}
// Ping asks the underlying WebSocket connection to send a ping and returns
// the transport's error, if any.
func (c *Conn) Ping(ctx context.Context) (err error) {
	err = c.conn.Ping(ctx)
	return err
}
// ConnHandler is a handler for a connection to the OpenAI Realtime API.
// It reads messages from the server in a standalone goroutine and calls the registered handlers.
//
// The caller must call Start to begin reading. To stop the read loop, cancel the
// context passed to NewConnHandler. The loop also exits when a read fails with a
// *PermanentError; presumably that is what happens once the Conn is closed
// (confirm against the WebSocketConn implementation in use).
// Receive from Err to wait for the read goroutine to exit.
//
// The handlers are called sequentially on the read goroutine, in the order they are registered.
// Users should not call ReadMessage directly when using ConnHandler.
type ConnHandler struct {
	ctx      context.Context      // controls the read loop's lifetime; also passed to every handler
	conn     *Conn                // connection the loop reads from
	handlers []ServerEventHandler // invoked in order for every decoded event
	errCh    chan error           // carries the loop's exit error, then is closed
}
// NewConnHandler creates a ConnHandler that reads events from conn and passes
// each one to handlers, in the given order. Nothing is read until Start is
// called. The read loop stops when ctx is done or a read fails with a
// *PermanentError.
//
// The handlers slice is copied, so modifying a slice passed as handlers...
// after this call neither affects nor races with the running read loop.
func NewConnHandler(ctx context.Context, conn *Conn, handlers ...ServerEventHandler) *ConnHandler {
	// Copy handlers so the read goroutine never shares a backing array with
	// the caller's slice. errCh is buffered because Start's goroutine sends at
	// most one error and must not block (and leak) if nobody ever receives it.
	return &ConnHandler{
		ctx:      ctx,
		conn:     conn,
		handlers: append([]ServerEventHandler(nil), handlers...),
		errCh:    make(chan error, 1),
	}
}
// Start launches the read loop on a new goroutine and returns immediately.
// When the loop exits, its non-nil error (if any) is delivered on the Err
// channel, which is then closed. Start is intended to be called once.
func (c *ConnHandler) Start() {
	go func() {
		defer close(c.errCh)
		if err := c.run(); err != nil {
			c.errCh <- err
		}
	}()
}
// Err returns a channel that reports how the ConnHandler's read goroutine ended.
// The goroutine sends at most one error on it and then closes it, so receiving
// from the channel can be used to wait for the goroutine to exit.
// If you don't need to wait for the goroutine to exit, there's no need to call this.
//
// Calling Err never blocks. Receiving from the returned channel blocks until the
// read loop stops, that is, until the context passed to NewConnHandler is done
// or a read fails with a *PermanentError. The received value is then ctx.Err()
// or the error carried by the *PermanentError; a receive on the closed channel
// yields nil. If Start was never called, a receive blocks forever.
func (c *ConnHandler) Err() <-chan error {
	return c.errCh
}
// run is the body of the read goroutine started by Start. It repeatedly reads
// the next server event and passes it to every handler, in registration order.
//
// It returns ctx.Err() once the handler's context is done, or the error wrapped
// by a *PermanentError returned from a read. Any other read error is logged as
// a warning and the loop moves on to the next message.
func (c *ConnHandler) run() error {
	for {
		if err := c.ctx.Err(); err != nil {
			return err
		}
		msg, err := c.conn.ReadMessage(c.ctx)
		if err != nil {
			var permanent *PermanentError
			if errors.As(err, &permanent) {
				return permanent.Err
			}
			// A read interrupted by cancellation is not a transient failure:
			// exit now rather than logging a spurious warning and looping once more.
			if ctxErr := c.ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			c.conn.logger.Warnf("read message temporary error: %+v", err)
			continue
		}
		for _, handler := range c.handlers {
			handler(c.ctx, msg)
		}
	}
}