Skip to content

Commit

Permalink
Use sync.Map for subscriptions map (#2422)
Browse files Browse the repository at this point in the history
* Use sync.Map for subscriptions map

* Add tests for Unsubscribe
  • Loading branch information
weiihann authored Feb 5, 2025
1 parent a1b2ff1 commit c646d21
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 32 deletions.
5 changes: 1 addition & 4 deletions rpc/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,9 @@ func (h *Handler) Events(args EventsArg) (*EventsChunk, *jsonrpc.Error) {
return &EventsChunk{Events: emittedEvents, ContinuationToken: cTokenStr}, nil
}

// unsubscribe assumes h.mu is unlocked. It releases all subscription resources.
func (h *Handler) unsubscribe(sub *subscription, id uint64) {
sub.cancel()
h.mu.Lock()
delete(h.subscriptions, id)
h.mu.Unlock()
h.subscriptions.Delete(id)
}

func setEventFilterRange(filter blockchain.EventFilterer, fromID, toID *BlockID, latestHeight uint64) error {
Expand Down
20 changes: 10 additions & 10 deletions rpc/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ type Handler struct {
l1Heads *feed.Feed[*core.L1Head]

idgen func() uint64
mu stdsync.Mutex // protects subscriptions.
subscriptions map[uint64]*subscription
subscriptions stdsync.Map // map[uint64]*subscription

blockTraceCache *lru.Cache[traceCacheKey, []TracedBlockTransaction]

Expand Down Expand Up @@ -136,12 +135,11 @@ func New(bcReader blockchain.Reader, syncReader sync.Reader, virtualMachine vm.V
}
return n
},
version: version,
newHeads: feed.New[*core.Header](),
reorgs: feed.New[*sync.ReorgBlockRange](),
pendingTxs: feed.New[[]core.Transaction](),
l1Heads: feed.New[*core.L1Head](),
subscriptions: make(map[uint64]*subscription),
version: version,
newHeads: feed.New[*core.Header](),
reorgs: feed.New[*sync.ReorgBlockRange](),
pendingTxs: feed.New[[]core.Transaction](),
l1Heads: feed.New[*core.L1Head](),

blockTraceCache: lru.NewCache[traceCacheKey, []TracedBlockTransaction](traceCacheSize),
filterLimit: math.MaxUint,
Expand Down Expand Up @@ -195,9 +193,11 @@ func (h *Handler) Run(ctx context.Context) error {
feed.Tee(l1HeadsSub, h.l1Heads)

<-ctx.Done()
for _, sub := range h.subscriptions {
h.subscriptions.Range(func(key, value any) bool {
sub := value.(*subscription)
sub.wg.Wait()
}
return true
})
return nil
}

Expand Down
33 changes: 15 additions & 18 deletions rpc/subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ func (h *Handler) SubscribeEvents(ctx context.Context, fromAddr *felt.Felt, keys
cancel: subscriptionCtxCancel,
conn: w,
}
h.mu.Lock()
h.subscriptions[id] = sub
h.mu.Unlock()
h.subscriptions.Store(id, sub)

headerSub := h.newHeads.Subscribe()
reorgSub := h.reorgs.Subscribe() // as per the spec, reorgs are also sent in the events subscription
Expand Down Expand Up @@ -155,9 +153,7 @@ func (h *Handler) SubscribeTransactionStatus(ctx context.Context, txHash felt.Fe
cancel: subscriptionCtxCancel,
conn: w,
}
h.mu.Lock()
h.subscriptions[id] = sub
h.mu.Unlock()
h.subscriptions.Store(id, sub)

l2HeadSub := h.newHeads.Subscribe()
l1HeadSub := h.l1Heads.Subscribe()
Expand Down Expand Up @@ -351,9 +347,7 @@ func (h *Handler) SubscribeNewHeads(ctx context.Context, blockID *BlockID) (*Sub
cancel: subscriptionCtxCancel,
conn: w,
}
h.mu.Lock()
h.subscriptions[id] = sub
h.mu.Unlock()
h.subscriptions.Store(id, sub)

headerSub := h.newHeads.Subscribe()
reorgSub := h.reorgs.Subscribe() // as per the spec, reorgs are also sent in the new heads subscription
Expand Down Expand Up @@ -406,9 +400,7 @@ func (h *Handler) SubscribePendingTxs(ctx context.Context, getDetails *bool, sen
cancel: subscriptionCtxCancel,
conn: w,
}
h.mu.Lock()
h.subscriptions[id] = sub
h.mu.Unlock()
h.subscriptions.Store(id, sub)

pendingTxsSub := h.pendingTxs.Subscribe()
sub.wg.Go(func() {
Expand Down Expand Up @@ -651,14 +643,19 @@ func (h *Handler) Unsubscribe(ctx context.Context, id uint64) (bool, *jsonrpc.Er
if !ok {
return false, jsonrpc.Err(jsonrpc.MethodNotFound, nil)
}
h.mu.Lock()
sub, ok := h.subscriptions[id]
h.mu.Unlock() // Don't defer since h.unsubscribe acquires the lock.
if !ok || !sub.conn.Equal(w) {
sub, ok := h.subscriptions.Load(id)
if !ok {
return false, ErrInvalidSubscriptionID
}
sub.cancel()
sub.wg.Wait() // Let the subscription finish before responding.

subs := sub.(*subscription)
if !subs.conn.Equal(w) {
return false, ErrInvalidSubscriptionID
}

subs.cancel()
subs.wg.Wait() // Let the subscription finish before responding.
h.subscriptions.Delete(id)
return true, nil
}

Expand Down
101 changes: 101 additions & 0 deletions rpc/subscriptions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,107 @@ func TestSubscribePendingTxs(t *testing.T) {
})
}

func TestUnsubscribe(t *testing.T) {
log := utils.NewNopZapLogger()

t.Run("error when no connection in context", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
t.Cleanup(mockCtrl.Finish)

mockChain := mocks.NewMockReader(mockCtrl)
mockSyncer := mocks.NewMockSyncReader(mockCtrl)
handler := New(mockChain, mockSyncer, nil, "", log)

success, rpcErr := handler.Unsubscribe(context.Background(), 1)
assert.False(t, success)
assert.Equal(t, jsonrpc.Err(jsonrpc.MethodNotFound, nil), rpcErr)
})

t.Run("error when subscription ID doesn't exist", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
t.Cleanup(mockCtrl.Finish)

mockChain := mocks.NewMockReader(mockCtrl)
mockSyncer := mocks.NewMockSyncReader(mockCtrl)
handler := New(mockChain, mockSyncer, nil, "", log)

serverConn, _ := net.Pipe()
t.Cleanup(func() {
require.NoError(t, serverConn.Close())
})

ctx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn})
success, rpcErr := handler.Unsubscribe(ctx, 999)
assert.False(t, success)
assert.Equal(t, ErrInvalidSubscriptionID, rpcErr)
})

t.Run("return false when connection doesn't match", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
t.Cleanup(mockCtrl.Finish)

mockChain := mocks.NewMockReader(mockCtrl)
mockSyncer := mocks.NewMockSyncReader(mockCtrl)
handler := New(mockChain, mockSyncer, nil, "", log)

// Create original subscription
serverConn1, _ := net.Pipe()
t.Cleanup(func() {
require.NoError(t, serverConn1.Close())
})

subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn1})
_, subscriptionCtxCancel := context.WithCancel(subCtx)
sub := &subscription{
cancel: subscriptionCtxCancel,
conn: &fakeConn{w: serverConn1},
}
handler.subscriptions.Store(uint64(1), sub)

// Try to unsubscribe with different connection
serverConn2, _ := net.Pipe()
t.Cleanup(func() {
require.NoError(t, serverConn2.Close())
})

unsubCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, &fakeConn{w: serverConn2})
success, rpcErr := handler.Unsubscribe(unsubCtx, 1)
assert.False(t, success)
assert.NotNil(t, rpcErr)
})

t.Run("successful unsubscribe", func(t *testing.T) {
mockCtrl := gomock.NewController(t)
t.Cleanup(mockCtrl.Finish)

mockChain := mocks.NewMockReader(mockCtrl)
mockSyncer := mocks.NewMockSyncReader(mockCtrl)
handler := New(mockChain, mockSyncer, nil, "", log)

serverConn, _ := net.Pipe()
t.Cleanup(func() {
require.NoError(t, serverConn.Close())
})

conn := &fakeConn{w: serverConn}
subCtx := context.WithValue(context.Background(), jsonrpc.ConnKey{}, conn)
_, subscriptionCtxCancel := context.WithCancel(subCtx)
sub := &subscription{
cancel: subscriptionCtxCancel,
conn: conn,
}
handler.subscriptions.Store(uint64(1), sub)

success, rpcErr := handler.Unsubscribe(subCtx, 1)
assert.True(t, success)
assert.Nil(t, rpcErr)

// Verify subscription was removed
_, exists := handler.subscriptions.Load(uint64(1))
assert.False(t, exists)
})
}

func createWsConn(t *testing.T, ctx context.Context, server *jsonrpc.Server) *websocket.Conn {
ws := jsonrpc.NewWebsocket(server, nil, utils.NewNopZapLogger())
httpSrv := httptest.NewServer(ws)
Expand Down

0 comments on commit c646d21

Please sign in to comment.