Skip to content

Commit

Permalink
Handle sse.Read error inside loop
Browse files Browse the repository at this point in the history
It looks more familiar, is easier to understand, sse.Read in the `range` clause looks much nicer, less explanation needed, more budget for inlining (one closure only), error handling before success path and not after, future-proof in case per-event errors are needed.
  • Loading branch information
tmaxmax committed Dec 28, 2024
1 parent 4153657 commit b08a688
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 46 deletions.
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ If you're here just to read ChatGPT's, Claude's or whichever LLM's response stre

```go
req, _ := http.NewRequestWithContext(ctx, http.MethodPost, "https://api.yourllm.com/v1/chat/completions", payload)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+yourKey)

res, err := http.DefaultClient.Do(req)
Expand All @@ -58,14 +59,12 @@ if err != nil {
}
defer res.Body.Close() // don't forget!!

events, errf := sse.Read(res, nil)
for ev := range events {
for ev, err := range sse.Read(res, nil) {
if err != nil {
// handle read error
break // can end the loop as Read stops on first error anyway
}
// Do something with the events, parse the JSON or whatever.
// Only valid events will be here, if there's an error iteration stops.
}
if err := errf(); err != nil {
// Handle any reading errors. This function must be called
// after iterating over the events!
}
```

Expand Down
10 changes: 7 additions & 3 deletions client_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,18 @@ func (c *Connection) read(r io.Reader, setRetry func(time.Duration)) error {
return p
}

events, errf := read(pf, c.lastEventID, func(r int64) { setRetry(time.Duration(r) * time.Millisecond) }, false)
events(func(e Event) bool {
var readErr error
read(pf, c.lastEventID, func(r int64) { setRetry(time.Duration(r) * time.Millisecond) }, false)(func(e Event, err error) bool {
if err != nil {
readErr = err
return false
}
c.lastEventID = e.LastEventID
c.dispatch(e)
return true
})

return errf()
return readErr
}

// Connect sends the request the connection was created with to the server
Expand Down
13 changes: 7 additions & 6 deletions cmd/llm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,13 @@ func main() {
return
}

events, errf := sse.Read(res.Body, nil)
for ev := range events {
for ev, err := range sse.Read(res.Body, nil) {
if err != nil {
fmt.Fprintf(os.Stderr, "while reading response body: %v\n", err)
// Can return – Read stops after first error and no subsequent events are parsed.
return
}

var data struct {
Choices []struct {
Delta struct {
Expand All @@ -75,8 +80,4 @@ func main() {

fmt.Printf("%s ", data.Choices[0].Delta.Content)
}
if err := errf(); err != nil {
fmt.Fprintf(os.Stderr, "while reading response body: %v\n", err)
return
}
}
41 changes: 21 additions & 20 deletions event.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ type ReadConfig struct {
MaxEventSize int
}

// Read parses an SSE stream and yields all incoming events.
// On any encountered errors Read stops execution. The error can be retrieved
// using the returned error function. If EOF is reached, the Read operation
// is considered successful and the error function will return a nil value.
// Read parses an SSE stream and yields all incoming events,
// On any encountered errors iteration stops and no further events are parsed –
// the loop can safely be ended on error. If EOF is reached, the Read operation
// is considered successful and no error is returned. An Event will never
// be yielded together with an error.
//
// Read is especially useful for parsing responses from services which
// communicate using SSE but not over long-lived connections – for example,
Expand All @@ -48,7 +49,7 @@ type ReadConfig struct {
//
// Read provides no way to handle the "retry" field and doesn't handle retrying.
// Use a Client and a Connection if you need to retry requests.
func Read(r io.Reader, cfg *ReadConfig) (func(func(Event) bool), func() error) {
func Read(r io.Reader, cfg *ReadConfig) func(func(Event, error) bool) {
pf := func() *parser.Parser {
p := parser.New(r)
if cfg != nil && cfg.MaxEventSize > 0 {
Expand All @@ -65,24 +66,16 @@ func Read(r io.Reader, cfg *ReadConfig) (func(func(Event) bool), func() error) {
return read(pf, "", nil, true)
}

func read(pf func() *parser.Parser, lastEventID string, onRetry func(int64), ignoreEOF bool) (func(func(Event) bool), func() error) {
var err error
errf := func() error {
if ignoreEOF && err == io.EOF { //nolint:errorlint // this is our error
err = nil
}
return err
}

return func(yield func(Event) bool) {
func read(pf func() *parser.Parser, lastEventID string, onRetry func(int64), ignoreEOF bool) func(func(Event, error) bool) {
return func(yield func(Event, error) bool) {
p := pf()

typ, sb, dirty := "", strings.Builder{}, false
doYield := func(data string) bool {
if data != "" {
data = data[:len(data)-1]
}
return yield(Event{LastEventID: lastEventID, Data: data, Type: typ})
return yield(Event{LastEventID: lastEventID, Data: data, Type: typ}, nil)
}

for f := (parser.Field{}); p.Next(&f); {
Expand Down Expand Up @@ -124,9 +117,17 @@ func read(pf func() *parser.Parser, lastEventID string, onRetry func(int64), ign
}
}

err = p.Err()
if dirty && err == io.EOF { //nolint:errorlint // Our scanner returns io.EOF unwrapped
doYield(sb.String())
err := p.Err()
isEOF := err == io.EOF //nolint:errorlint // Our scanner returns io.EOF unwrapped

if dirty && isEOF {
if !doYield(sb.String()) {
return
}
}

if err != nil && !(ignoreEOF && isEOF) {
yield(Event{}, err)
}
}, errf
}
}
28 changes: 18 additions & 10 deletions event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,43 @@ func TestRead(t *testing.T) {

var recv []sse.Event

events, errf := sse.Read(response, nil)
events(func(e sse.Event) bool {
events := sse.Read(response, nil)
events(func(e sse.Event, err error) bool {
tests.Equal(t, err, nil, "unexpected error")
recv = append(recv, e)
return true
})

tests.Equal(t, errf(), nil, "unexpected error")
tests.DeepEqual(t, recv, []sse.Event{{Data: "Hello World!"}}, "incorrect result")

t.Run("Buffer", func(t *testing.T) {
_, _ = response.Seek(0, io.SeekStart)

events, errf := sse.Read(response, &sse.ReadConfig{MaxEventSize: 3})
events(func(sse.Event) bool { return true })
tests.Expect(t, errf() != nil, "should fail because of too small buffer")
events := sse.Read(response, &sse.ReadConfig{MaxEventSize: 3})
var err error
events(func(_ sse.Event, e error) bool { err = e; return err == nil })
tests.Expect(t, err != nil, "should fail because of too small buffer")
})

t.Run("Break", func(t *testing.T) {
events, errf := sse.Read(strings.NewReader("id: a\n\nid: b\n\nid: c\n"), nil) // also test EOF edge case
events := sse.Read(strings.NewReader("id: a\n\nid: b\n\nid: c\n"), nil) // also test EOF edge case

var recv []sse.Event
events(func(e sse.Event) bool {
events(func(e sse.Event, err error) bool {
tests.Equal(t, err, nil, "unexpected error")
recv = append(recv, e)
return len(recv) < 2
})

tests.Equal(t, errf(), nil, "unexpected error")

expected := []sse.Event{{LastEventID: "a"}, {LastEventID: "b"}}
tests.DeepEqual(t, recv, expected, "iterator didn't stop")

// Cover break check on EOF edge case
// NOTE(tmaxmax): Should also test this with EOF return when possible.
sse.Read(strings.NewReader("data: x\n"), nil)(func(e sse.Event, err error) bool {
tests.Equal(t, err, nil, "unexpected error")
tests.Equal(t, e, sse.Event{Data: "x"}, "unexpected event")
return false
})
})
}

0 comments on commit b08a688

Please sign in to comment.