From f8ec2bd3c7b4cfeeef2fa8d5d756ce924f889991 Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Thu, 16 Jan 2025 18:09:07 +0700 Subject: [PATCH 01/10] feat(parachain): send requests from subsystems --- dot/network/service.go | 2 +- dot/parachain/collator-protocol/mocks_test.go | 27 +-- .../collator-protocol/validator_side.go | 4 +- dot/parachain/network-bridge/interface.go | 2 +- .../messages}/chunk_fetching.go | 14 +- .../messages}/chunk_fetching_test.go | 2 +- .../messages/request_response_protocols.go | 68 +++++++ .../messages/tx_overseer_messages.go | 13 ++ dot/parachain/network-bridge/sender.go | 35 +++- dot/parachain/network-bridge/sender_test.go | 191 ++++++++++++++++++ dot/parachain/network_protocols.go | 36 ---- dot/parachain/overseer/overseer.go | 2 +- dot/parachain/service.go | 2 +- 13 files changed, 340 insertions(+), 58 deletions(-) rename dot/parachain/{ => network-bridge/messages}/chunk_fetching.go (87%) rename dot/parachain/{ => network-bridge/messages}/chunk_fetching_test.go (99%) create mode 100644 dot/parachain/network-bridge/messages/request_response_protocols.go create mode 100644 dot/parachain/network-bridge/sender_test.go diff --git a/dot/network/service.go b/dot/network/service.go index 738f747378..e662cca5ec 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -593,7 +593,7 @@ func (s *Service) SendMessage(to peer.ID, msg NotificationsMessage) error { } func (s *Service) GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, - maxResponseSize uint64) *RequestResponseProtocol { + maxResponseSize uint64) RequestMaker { protocolID := s.host.protocolID + protocol.ID(subprotocol) return &RequestResponseProtocol{ diff --git a/dot/parachain/collator-protocol/mocks_test.go b/dot/parachain/collator-protocol/mocks_test.go index cfc5504e56..1ae610cc29 100644 --- a/dot/parachain/collator-protocol/mocks_test.go +++ b/dot/parachain/collator-protocol/mocks_test.go @@ -23,6 +23,7 @@ import ( type MockNetwork struct { ctrl *gomock.Controller recorder *MockNetworkMockRecorder + isgomock struct{} } // MockNetworkMockRecorder is the mock recorder for MockNetwork. @@ -43,41 +44,41 @@ func (m *MockNetwork) EXPECT() *MockNetworkMockRecorder { } // GetRequestResponseProtocol mocks base method. -func (m *MockNetwork) GetRequestResponseProtocol(arg0 string, arg1 time.Duration, arg2 uint64) *network.RequestResponseProtocol { +func (m *MockNetwork) GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, maxResponseSize uint64) network.RequestMaker { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRequestResponseProtocol", arg0, arg1, arg2) - ret0, _ := ret[0].(*network.RequestResponseProtocol) + ret := m.ctrl.Call(m, "GetRequestResponseProtocol", subprotocol, requestTimeout, maxResponseSize) + ret0, _ := ret[0].(network.RequestMaker) return ret0 } // GetRequestResponseProtocol indicates an expected call of GetRequestResponseProtocol. -func (mr *MockNetworkMockRecorder) GetRequestResponseProtocol(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockNetworkMockRecorder) GetRequestResponseProtocol(subprotocol, requestTimeout, maxResponseSize any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestResponseProtocol", reflect.TypeOf((*MockNetwork)(nil).GetRequestResponseProtocol), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestResponseProtocol", reflect.TypeOf((*MockNetwork)(nil).GetRequestResponseProtocol), subprotocol, requestTimeout, maxResponseSize) } // GossipMessage mocks base method. -func (m *MockNetwork) GossipMessage(arg0 network.NotificationsMessage) { +func (m *MockNetwork) GossipMessage(msg network.NotificationsMessage) { m.ctrl.T.Helper() - m.ctrl.Call(m, "GossipMessage", arg0) + m.ctrl.Call(m, "GossipMessage", msg) } // GossipMessage indicates an expected call of GossipMessage. -func (mr *MockNetworkMockRecorder) GossipMessage(arg0 any) *gomock.Call { +func (mr *MockNetworkMockRecorder) GossipMessage(msg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GossipMessage", reflect.TypeOf((*MockNetwork)(nil).GossipMessage), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GossipMessage", reflect.TypeOf((*MockNetwork)(nil).GossipMessage), msg) } // RegisterNotificationsProtocol mocks base method. -func (m *MockNetwork) RegisterNotificationsProtocol(arg0 protocol.ID, arg1 network.MessageType, arg2 func() (network.Handshake, error), arg3 func([]byte) (network.Handshake, error), arg4 func(peer.ID, network.Handshake) error, arg5 func([]byte) (network.NotificationsMessage, error), arg6 func(peer.ID, network.NotificationsMessage) (bool, error), arg7 func(peer.ID, network.NotificationsMessage), arg8 uint64) error { +func (m *MockNetwork) RegisterNotificationsProtocol(sub protocol.ID, messageID network.MessageType, handshakeGetter func() (network.Handshake, error), handshakeDecoder func([]byte) (network.Handshake, error), handshakeValidator func(peer.ID, network.Handshake) error, messageDecoder func([]byte) (network.NotificationsMessage, error), messageHandler func(peer.ID, network.NotificationsMessage) (bool, error), batchHandler func(peer.ID, network.NotificationsMessage), maxSize uint64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RegisterNotificationsProtocol", arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) + ret := m.ctrl.Call(m, "RegisterNotificationsProtocol", sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize) ret0, _ := ret[0].(error) return ret0 } // RegisterNotificationsProtocol indicates an expected call of RegisterNotificationsProtocol. -func (mr *MockNetworkMockRecorder) RegisterNotificationsProtocol(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) *gomock.Call { +func (mr *MockNetworkMockRecorder) RegisterNotificationsProtocol(sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterNotificationsProtocol", reflect.TypeOf((*MockNetwork)(nil).RegisterNotificationsProtocol), arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterNotificationsProtocol", reflect.TypeOf((*MockNetwork)(nil).RegisterNotificationsProtocol), sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize) } diff --git a/dot/parachain/collator-protocol/validator_side.go b/dot/parachain/collator-protocol/validator_side.go index ddd5b576e6..ff6a972db0 100644 --- a/dot/parachain/collator-protocol/validator_side.go +++ b/dot/parachain/collator-protocol/validator_side.go @@ -559,7 +559,7 @@ type Network interface { maxSize uint64, ) error GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, - maxResponseSize uint64) *network.RequestResponseProtocol + maxResponseSize uint64) network.RequestMaker } type CollationEvent struct { @@ -574,7 +574,7 @@ type CollatorProtocolValidatorSide struct { SubSystemToOverseer chan<- any unfetchedCollation chan UnfetchedCollation - collationFetchingReqResProtocol *network.RequestResponseProtocol + collationFetchingReqResProtocol network.RequestMaker fetchedCollations []parachaintypes.Collation // track all active collators and their data diff --git a/dot/parachain/network-bridge/interface.go b/dot/parachain/network-bridge/interface.go index d070a3c592..baa7d6b17d 100644 --- a/dot/parachain/network-bridge/interface.go +++ b/dot/parachain/network-bridge/interface.go @@ -27,7 +27,7 @@ type Network interface { maxSize uint64, ) error GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, - maxResponseSize uint64) *network.RequestResponseProtocol + maxResponseSize uint64) network.RequestMaker ReportPeer(change peerset.ReputationChange, p peer.ID) DisconnectPeer(setID int, p peer.ID) GetNetworkEventsChannel() chan *network.NetworkEventInfo diff --git a/dot/parachain/chunk_fetching.go b/dot/parachain/network-bridge/messages/chunk_fetching.go similarity index 87% rename from dot/parachain/chunk_fetching.go rename to dot/parachain/network-bridge/messages/chunk_fetching.go index cd3086da39..aaa56f3185 100644 --- a/dot/parachain/chunk_fetching.go +++ b/dot/parachain/network-bridge/messages/chunk_fetching.go @@ -1,11 +1,13 @@ // Copyright 2023 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package parachain +package messages import ( "fmt" + "github.com/ChainSafe/gossamer/dot/network" + parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" "github.com/ChainSafe/gossamer/pkg/scale" ) @@ -24,6 +26,16 @@ func (c ChunkFetchingRequest) Encode() ([]byte, error) { return scale.Marshal(c) } +// Protocol returns the sub-protocol ID for this message +func (c ChunkFetchingRequest) Protocol() ReqProtocolName { + return ChunkFetchingV1 +} + +// Response returns an instance of the response type for this message, for the purpose of decoding into it. +func (c ChunkFetchingRequest) Response() network.ResponseMessage { + return &ChunkFetchingResponse{} +} + type ChunkFetchingResponseValues interface { ChunkResponse | NoSuchChunk } diff --git a/dot/parachain/chunk_fetching_test.go b/dot/parachain/network-bridge/messages/chunk_fetching_test.go similarity index 99% rename from dot/parachain/chunk_fetching_test.go rename to dot/parachain/network-bridge/messages/chunk_fetching_test.go index a44d1962f4..5aa1b0ac77 100644 --- a/dot/parachain/chunk_fetching_test.go +++ b/dot/parachain/network-bridge/messages/chunk_fetching_test.go @@ -1,7 +1,7 @@ // Copyright 2023 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only -package parachain +package messages import ( "testing" diff --git a/dot/parachain/network-bridge/messages/request_response_protocols.go b/dot/parachain/network-bridge/messages/request_response_protocols.go new file mode 100644 index 0000000000..0cf16b51b3 --- /dev/null +++ b/dot/parachain/network-bridge/messages/request_response_protocols.go @@ -0,0 +1,68 @@ +package messages + +import ( + "github.com/ChainSafe/gossamer/dot/network" + "github.com/libp2p/go-libp2p/core/peer" +) + +type ReqProtocolName uint + +const ( + ChunkFetchingV1 ReqProtocolName = iota + CollationFetchingV1 + PoVFetchingV1 + AvailableDataFetchingV1 + StatementFetchingV1 + DisputeSendingV1 +) + +func (n ReqProtocolName) String() string { + switch n { + case ChunkFetchingV1: + return "req_chunk/1" + case CollationFetchingV1: + return "req_collation/1" + case PoVFetchingV1: + return "req_pov/1" + case AvailableDataFetchingV1: + return "req_available_data/1" + case StatementFetchingV1: + return "req_statement/1" + case DisputeSendingV1: + return "send_dispute/1" + default: + panic("unknown protocol") + } +} + +// ReqProtocolMessage is a network message that can be sent over a request response protocol. +type ReqProtocolMessage interface { + network.Message + // Response returns an instance of the response type for this message, for the purpose of decoding into it. + Response() network.ResponseMessage + Protocol() ReqProtocolName +} + +// ReqRespResult is the result of sending a request over a request response protocol. It contains either a response +// message or an error. +type ReqRespResult struct { + Response network.ResponseMessage + Error error +} + +// OutgoingRequest contains all data required to send a request over a request response protocol and receive the result. +type OutgoingRequest struct { + Recipient peer.ID // TODO use a type that can contain either a peer ID or an authority ID + Payload ReqProtocolMessage + Result chan ReqRespResult +} + +func NewOutgoingRequest(recipient peer.ID, payload ReqProtocolMessage) *OutgoingRequest { + result := make(chan ReqRespResult, 1) + + return &OutgoingRequest{ + Recipient: recipient, + Payload: payload, + Result: result, + } +} diff --git a/dot/parachain/network-bridge/messages/tx_overseer_messages.go b/dot/parachain/network-bridge/messages/tx_overseer_messages.go index e7d454689e..8c41e3613c 100644 --- a/dot/parachain/network-bridge/messages/tx_overseer_messages.go +++ b/dot/parachain/network-bridge/messages/tx_overseer_messages.go @@ -63,3 +63,16 @@ type ConnectToValidators struct { // authority discovery has Failed to resolve. Failed chan<- uint } + +type IfDisconnectedBehavior int + +const ( + TryConnect IfDisconnectedBehavior = iota + ImmediateError // TODO not implemented +) + +// SendRequests is a subsystem message for sending requests over a request response protocol. +type SendRequests struct { + Requests []*OutgoingRequest + IfDisconnected IfDisconnectedBehavior +} diff --git a/dot/parachain/network-bridge/sender.go b/dot/parachain/network-bridge/sender.go index e1a2df06e8..f38928d280 100644 --- a/dot/parachain/network-bridge/sender.go +++ b/dot/parachain/network-bridge/sender.go @@ -6,6 +6,7 @@ package networkbridge import ( "context" "fmt" + "time" "github.com/ChainSafe/gossamer/dot/network" networkbridgemessages "github.com/ChainSafe/gossamer/dot/parachain/network-bridge/messages" @@ -92,7 +93,9 @@ func (nbs *NetworkBridgeSender) processMessage(msg any) error { return fmt.Errorf("sending message: %w", err) } } - // TODO: add ConnectTOResolvedValidators, SendRequests + case networkbridgemessages.SendRequests: + nbs.sendRequests(msg.Requests, msg.IfDisconnected) + // TODO: add ConnectTOResolvedValidators case networkbridgemessages.ConnectToValidators: // TODO case networkbridgemessages.ReportPeer: @@ -104,3 +107,33 @@ func (nbs *NetworkBridgeSender) processMessage(msg any) error { return nil } + +const requestTimeout = 200 * time.Millisecond // TODO is this reasonable? + +// PoV is probably the largest message and is currently set at 5MB, but will likely be increased to 10MB in the future. +// see: https://github.com/paritytech/polkadot-sdk/issues/5334 +// Maybe message types should have a MaxSize() method instead of using the same value for all messages. +const maxResponseSize uint64 = 5 * 1024 * 1024 + +func (nbs *NetworkBridgeSender) sendRequests( + requests []*networkbridgemessages.OutgoingRequest, + ifDisconnected networkbridgemessages.IfDisconnectedBehavior, //nolint:unparam +) { + for _, request := range requests { + protoID := request.Payload.Protocol().String() + protocol := nbs.net.GetRequestResponseProtocol(protoID, requestTimeout, maxResponseSize) + response := request.Payload.Response() + result := networkbridgemessages.ReqRespResult{} + + // TODO This should probably be done on a goroutine. Unclear how to deal with cancellation/shutdown though. + err := protocol.Do(request.Recipient, request.Payload, response) + if err != nil { + result.Error = err + } else { + result.Response = response + } + + request.Result <- result + close(request.Result) + } +} diff --git a/dot/parachain/network-bridge/sender_test.go b/dot/parachain/network-bridge/sender_test.go new file mode 100644 index 0000000000..02d129aaee --- /dev/null +++ b/dot/parachain/network-bridge/sender_test.go @@ -0,0 +1,191 @@ +package networkbridge + +import ( + "errors" + "testing" + "time" + + "github.com/ChainSafe/gossamer/dot/network" + networkbridgemessages "github.com/ChainSafe/gossamer/dot/parachain/network-bridge/messages" + parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" + "github.com/ChainSafe/gossamer/dot/peerset" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" + "github.com/stretchr/testify/require" +) + +func TestSendRequests(t *testing.T) { + t.Run("request_succeeds", func(t *testing.T) { + request := makeOutgoingRequest(t) + response := &networkbridgemessages.ChunkFetchingResponse{} + expectedValue := networkbridgemessages.NoSuchChunk{} + require.NoError(t, response.SetValue(expectedValue)) + + nbs := setUpNetworkBridgeSender(t, response, nil, nil) + + sendRequests := networkbridgemessages.SendRequests{ + Requests: []*networkbridgemessages.OutgoingRequest{request}, + IfDisconnected: networkbridgemessages.TryConnect, + } + + err := nbs.processMessage(sendRequests) + require.NoError(t, err) + + result := <-request.Result + require.NoError(t, result.Error) + + cfResponse, ok := result.Response.(*networkbridgemessages.ChunkFetchingResponse) + require.True(t, ok) + + actualValue, err := cfResponse.Value() + require.NoError(t, err) + + require.Equal(t, expectedValue, actualValue) + requireClosed(t, request.Result) + }) + + t.Run("request_fails", func(t *testing.T) { + reqErr := errors.New("timeout") + request := makeOutgoingRequest(t) + + nbs := setUpNetworkBridgeSender(t, nil, nil, reqErr) + + sendRequests := networkbridgemessages.SendRequests{ + Requests: []*networkbridgemessages.OutgoingRequest{request}, + IfDisconnected: networkbridgemessages.TryConnect, + } + + err := nbs.processMessage(sendRequests) + require.NoError(t, err) + + result := <-request.Result + require.Equal(t, reqErr, result.Error) + require.Nil(t, result.Response) + requireClosed(t, request.Result) + }) + + t.Run("decoding_fails", func(t *testing.T) { + request := makeOutgoingRequest(t) + rawResponse := []byte("an invalid network response message") + + nbs := setUpNetworkBridgeSender(t, nil, rawResponse, nil) + + sendRequests := networkbridgemessages.SendRequests{ + Requests: []*networkbridgemessages.OutgoingRequest{request}, + IfDisconnected: networkbridgemessages.TryConnect, + } + + err := nbs.processMessage(sendRequests) + require.NoError(t, err) + + result := <-request.Result + require.Error(t, result.Error) + require.Nil(t, result.Response) + requireClosed(t, request.Result) + }) +} + +// We arbitrarily use a ChunkFetchingRequest since it does not matter for testing SendRequests handling. +func makeOutgoingRequest(t *testing.T) *networkbridgemessages.OutgoingRequest { + t.Helper() + + return networkbridgemessages.NewOutgoingRequest( + "recipient", + networkbridgemessages.ChunkFetchingRequest{ + CandidateHash: parachaintypes.CandidateHash{Value: common.Hash{1}}, + Index: 42, + }) +} + +// only one of response, rawResponse or reqErr should be non-nil +func setUpNetworkBridgeSender( + t *testing.T, + response network.ResponseMessage, + rawResponse []byte, + reqErr error, +) *NetworkBridgeSender { + t.Helper() + + if response != nil { + var err error + rawResponse, err = response.Encode() + require.NoError(t, err) + } + + netService := &mockNetworkService{ + rrp: &mockRequestResponseProtocol{ + rawResponse: rawResponse, + err: reqErr, + }, + } + + return RegisterSender(nil, netService) +} + +func requireClosed(t *testing.T, ch chan networkbridgemessages.ReqRespResult) { + select { + case <-ch: + default: + t.Error("channel is not closed") + } +} + +// TODO use gomock +type mockRequestResponseProtocol struct { + rawResponse []byte + err error +} + +func (m *mockRequestResponseProtocol) Do(to peer.ID, req network.Message, response network.ResponseMessage) error { + if m.err != nil { + return m.err + } + + if err := response.Decode(m.rawResponse); err != nil { + return err + } + return nil +} + +// TODO use gomock +type mockNetworkService struct { + rrp *mockRequestResponseProtocol +} + +func (m *mockNetworkService) GossipMessage(msg network.NotificationsMessage) {} + +func (m *mockNetworkService) SendMessage(to peer.ID, msg network.NotificationsMessage) error { + return nil +} + +func (m *mockNetworkService) RegisterNotificationsProtocol(sub protocol.ID, + messageID network.MessageType, + handshakeGetter network.HandshakeGetter, + handshakeDecoder network.HandshakeDecoder, + handshakeValidator network.HandshakeValidator, + messageDecoder network.MessageDecoder, + messageHandler network.NotificationsMessageHandler, + batchHandler network.NotificationsMessageBatchHandler, + maxSize uint64, +) error { + return nil +} + +func (m *mockNetworkService) GetRequestResponseProtocol( + subprotocol string, + requestTimeout time.Duration, + maxResponseSize uint64, +) network.RequestMaker { + return m.rrp +} + +func (m *mockNetworkService) ReportPeer(change peerset.ReputationChange, p peer.ID) {} + +func (m *mockNetworkService) DisconnectPeer(setID int, p peer.ID) {} + +func (m *mockNetworkService) GetNetworkEventsChannel() chan *network.NetworkEventInfo { + return nil +} + +func (m *mockNetworkService) FreeNetworkEventsChannel(ch chan *network.NetworkEventInfo) {} diff --git a/dot/parachain/network_protocols.go b/dot/parachain/network_protocols.go index 880febfbed..ee7434e295 100644 --- a/dot/parachain/network_protocols.go +++ b/dot/parachain/network_protocols.go @@ -10,17 +10,6 @@ import ( "github.com/ChainSafe/gossamer/lib/common" ) -type ReqProtocolName uint - -const ( - ChunkFetchingV1 ReqProtocolName = iota - CollationFetchingV1 - PoVFetchingV1 - AvailableDataFetchingV1 - StatementFetchingV1 - DisputeSendingV1 -) - type PeerSetProtocolName uint const ( @@ -28,31 +17,6 @@ const ( CollationProtocolName ) -func GenerateReqProtocolName(protocol ReqProtocolName, forkID string, GenesisHash common.Hash) string { - prefix := fmt.Sprintf("/%s", GenesisHash.String()) - - if forkID != "" { - prefix = fmt.Sprintf("%s/%s", prefix, forkID) - } - - switch protocol { - case ChunkFetchingV1: - return fmt.Sprintf("%s/req_chunk/1", prefix) - case CollationFetchingV1: - return fmt.Sprintf("%s/req_collation/1", prefix) - case PoVFetchingV1: - return fmt.Sprintf("%s/req_pov/1", prefix) - case AvailableDataFetchingV1: - return fmt.Sprintf("%s/req_available_data/1", prefix) - case StatementFetchingV1: - return fmt.Sprintf("%s/req_statement/1", prefix) - case DisputeSendingV1: - return fmt.Sprintf("%s/send_dispute/1", prefix) - default: - panic("unknown protocol") - } -} - func GeneratePeersetProtocolName(protocol PeerSetProtocolName, forkID string, GenesisHash common.Hash, version uint32, ) string { genesisHash := GenesisHash.String() diff --git a/dot/parachain/overseer/overseer.go b/dot/parachain/overseer/overseer.go index 626143c8e3..d0a8aea88e 100644 --- a/dot/parachain/overseer/overseer.go +++ b/dot/parachain/overseer/overseer.go @@ -135,7 +135,7 @@ func (o *OverseerSystem) processMessages() { case networkbridgemessages.DisconnectPeer, networkbridgemessages.ConnectToValidators, networkbridgemessages.ReportPeer, networkbridgemessages.SendCollationMessage, - networkbridgemessages.SendValidationMessage: + networkbridgemessages.SendValidationMessage, networkbridgemessages.SendRequests: subsystem = o.nameToSubsystem[parachaintypes.NetworkBridgeSender] case networkbridgeevents.Event[collatorprotocolmessages.CollationProtocol]: diff --git a/dot/parachain/service.go b/dot/parachain/service.go index 2e86939d2b..14ec9885c7 100644 --- a/dot/parachain/service.go +++ b/dot/parachain/service.go @@ -168,7 +168,7 @@ type Network interface { maxSize uint64, ) error GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, - maxResponseSize uint64) *network.RequestResponseProtocol + maxResponseSize uint64) network.RequestMaker ReportPeer(change peerset.ReputationChange, p peer.ID) DisconnectPeer(setID int, p peer.ID) GetNetworkEventsChannel() chan *network.NetworkEventInfo From 93ca5ddfe92a7ec4d6c7b390dc22b22fb0d8088e Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Thu, 23 Jan 2025 14:17:52 +0700 Subject: [PATCH 02/10] use gomock --- .../network-bridge/mock_request_maker_test.go | 56 +++++++ .../network-bridge/mocks_generate_test.go | 7 + dot/parachain/network-bridge/mocks_test.go | 149 ++++++++++++++++++ dot/parachain/network-bridge/sender_test.go | 117 +++++--------- 4 files changed, 251 insertions(+), 78 deletions(-) create mode 100644 dot/parachain/network-bridge/mock_request_maker_test.go create mode 100644 dot/parachain/network-bridge/mocks_generate_test.go create mode 100644 dot/parachain/network-bridge/mocks_test.go diff --git a/dot/parachain/network-bridge/mock_request_maker_test.go b/dot/parachain/network-bridge/mock_request_maker_test.go new file mode 100644 index 0000000000..166f87e5cf --- /dev/null +++ b/dot/parachain/network-bridge/mock_request_maker_test.go @@ -0,0 +1,56 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ChainSafe/gossamer/dot/network (interfaces: RequestMaker) +// +// Generated by this command: +// +// mockgen -destination=mock_request_maker_test.go -package=networkbridge github.com/ChainSafe/gossamer/dot/network RequestMaker +// + +// Package networkbridge is a generated GoMock package. +package networkbridge + +import ( + reflect "reflect" + + network "github.com/ChainSafe/gossamer/dot/network" + peer "github.com/libp2p/go-libp2p/core/peer" + gomock "go.uber.org/mock/gomock" +) + +// MockRequestMaker is a mock of RequestMaker interface. +type MockRequestMaker struct { + ctrl *gomock.Controller + recorder *MockRequestMakerMockRecorder + isgomock struct{} +} + +// MockRequestMakerMockRecorder is the mock recorder for MockRequestMaker. +type MockRequestMakerMockRecorder struct { + mock *MockRequestMaker +} + +// NewMockRequestMaker creates a new mock instance. +func NewMockRequestMaker(ctrl *gomock.Controller) *MockRequestMaker { + mock := &MockRequestMaker{ctrl: ctrl} + mock.recorder = &MockRequestMakerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRequestMaker) EXPECT() *MockRequestMakerMockRecorder { + return m.recorder +} + +// Do mocks base method. +func (m *MockRequestMaker) Do(to peer.ID, req network.Message, res network.ResponseMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Do", to, req, res) + ret0, _ := ret[0].(error) + return ret0 +} + +// Do indicates an expected call of Do. +func (mr *MockRequestMakerMockRecorder) Do(to, req, res any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockRequestMaker)(nil).Do), to, req, res) +} diff --git a/dot/parachain/network-bridge/mocks_generate_test.go b/dot/parachain/network-bridge/mocks_generate_test.go new file mode 100644 index 0000000000..a97b34b31b --- /dev/null +++ b/dot/parachain/network-bridge/mocks_generate_test.go @@ -0,0 +1,7 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package networkbridge + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . Network +//go:generate mockgen -destination=mock_request_maker_test.go -package=$GOPACKAGE github.com/ChainSafe/gossamer/dot/network RequestMaker diff --git a/dot/parachain/network-bridge/mocks_test.go b/dot/parachain/network-bridge/mocks_test.go new file mode 100644 index 0000000000..ae19f2dcef --- /dev/null +++ b/dot/parachain/network-bridge/mocks_test.go @@ -0,0 +1,149 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/ChainSafe/gossamer/dot/parachain/network-bridge (interfaces: Network) +// +// Generated by this command: +// +// mockgen -destination=mocks_test.go -package=networkbridge . Network +// + +// Package networkbridge is a generated GoMock package. +package networkbridge + +import ( + reflect "reflect" + time "time" + + network "github.com/ChainSafe/gossamer/dot/network" + peerset "github.com/ChainSafe/gossamer/dot/peerset" + peer "github.com/libp2p/go-libp2p/core/peer" + protocol "github.com/libp2p/go-libp2p/core/protocol" + gomock "go.uber.org/mock/gomock" +) + +// MockNetwork is a mock of Network interface. +type MockNetwork struct { + ctrl *gomock.Controller + recorder *MockNetworkMockRecorder + isgomock struct{} +} + +// MockNetworkMockRecorder is the mock recorder for MockNetwork. +type MockNetworkMockRecorder struct { + mock *MockNetwork +} + +// NewMockNetwork creates a new mock instance. +func NewMockNetwork(ctrl *gomock.Controller) *MockNetwork { + mock := &MockNetwork{ctrl: ctrl} + mock.recorder = &MockNetworkMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetwork) EXPECT() *MockNetworkMockRecorder { + return m.recorder +} + +// DisconnectPeer mocks base method. +func (m *MockNetwork) DisconnectPeer(setID int, p peer.ID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisconnectPeer", setID, p) +} + +// DisconnectPeer indicates an expected call of DisconnectPeer. +func (mr *MockNetworkMockRecorder) DisconnectPeer(setID, p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeer", reflect.TypeOf((*MockNetwork)(nil).DisconnectPeer), setID, p) +} + +// FreeNetworkEventsChannel mocks base method. +func (m *MockNetwork) FreeNetworkEventsChannel(ch chan *network.NetworkEventInfo) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "FreeNetworkEventsChannel", ch) +} + +// FreeNetworkEventsChannel indicates an expected call of FreeNetworkEventsChannel. +func (mr *MockNetworkMockRecorder) FreeNetworkEventsChannel(ch any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FreeNetworkEventsChannel", reflect.TypeOf((*MockNetwork)(nil).FreeNetworkEventsChannel), ch) +} + +// GetNetworkEventsChannel mocks base method. +func (m *MockNetwork) GetNetworkEventsChannel() chan *network.NetworkEventInfo { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNetworkEventsChannel") + ret0, _ := ret[0].(chan *network.NetworkEventInfo) + return ret0 +} + +// GetNetworkEventsChannel indicates an expected call of GetNetworkEventsChannel. +func (mr *MockNetworkMockRecorder) GetNetworkEventsChannel() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkEventsChannel", reflect.TypeOf((*MockNetwork)(nil).GetNetworkEventsChannel)) +} + +// GetRequestResponseProtocol mocks base method. +func (m *MockNetwork) GetRequestResponseProtocol(subprotocol string, requestTimeout time.Duration, maxResponseSize uint64) network.RequestMaker { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRequestResponseProtocol", subprotocol, requestTimeout, maxResponseSize) + ret0, _ := ret[0].(network.RequestMaker) + return ret0 +} + +// GetRequestResponseProtocol indicates an expected call of GetRequestResponseProtocol. +func (mr *MockNetworkMockRecorder) GetRequestResponseProtocol(subprotocol, requestTimeout, maxResponseSize any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestResponseProtocol", reflect.TypeOf((*MockNetwork)(nil).GetRequestResponseProtocol), subprotocol, requestTimeout, maxResponseSize) +} + +// GossipMessage mocks base method. +func (m *MockNetwork) GossipMessage(msg network.NotificationsMessage) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "GossipMessage", msg) +} + +// GossipMessage indicates an expected call of GossipMessage. +func (mr *MockNetworkMockRecorder) GossipMessage(msg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GossipMessage", reflect.TypeOf((*MockNetwork)(nil).GossipMessage), msg) +} + +// RegisterNotificationsProtocol mocks base method. +func (m *MockNetwork) RegisterNotificationsProtocol(sub protocol.ID, messageID network.MessageType, handshakeGetter func() (network.Handshake, error), handshakeDecoder func([]byte) (network.Handshake, error), handshakeValidator func(peer.ID, network.Handshake) error, messageDecoder func([]byte) (network.NotificationsMessage, error), messageHandler func(peer.ID, network.NotificationsMessage) (bool, error), batchHandler func(peer.ID, network.NotificationsMessage), maxSize uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterNotificationsProtocol", sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterNotificationsProtocol indicates an expected call of RegisterNotificationsProtocol. +func (mr *MockNetworkMockRecorder) RegisterNotificationsProtocol(sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterNotificationsProtocol", reflect.TypeOf((*MockNetwork)(nil).RegisterNotificationsProtocol), sub, messageID, handshakeGetter, handshakeDecoder, handshakeValidator, messageDecoder, messageHandler, batchHandler, maxSize) +} + +// ReportPeer mocks base method. +func (m *MockNetwork) ReportPeer(change peerset.ReputationChange, p peer.ID) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReportPeer", change, p) +} + +// ReportPeer indicates an expected call of ReportPeer. +func (mr *MockNetworkMockRecorder) ReportPeer(change, p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReportPeer", reflect.TypeOf((*MockNetwork)(nil).ReportPeer), change, p) +} + +// SendMessage mocks base method. +func (m *MockNetwork) SendMessage(to peer.ID, msg network.NotificationsMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendMessage", to, msg) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendMessage indicates an expected call of SendMessage. +func (mr *MockNetworkMockRecorder) SendMessage(to, msg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendMessage", reflect.TypeOf((*MockNetwork)(nil).SendMessage), to, msg) +} diff --git a/dot/parachain/network-bridge/sender_test.go b/dot/parachain/network-bridge/sender_test.go index 02d129aaee..0dc8e4de08 100644 --- a/dot/parachain/network-bridge/sender_test.go +++ b/dot/parachain/network-bridge/sender_test.go @@ -2,27 +2,27 @@ package networkbridge import ( "errors" - "testing" - "time" - "github.com/ChainSafe/gossamer/dot/network" networkbridgemessages "github.com/ChainSafe/gossamer/dot/parachain/network-bridge/messages" parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" - "github.com/ChainSafe/gossamer/dot/peerset" "github.com/ChainSafe/gossamer/lib/common" "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/protocol" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "testing" ) func TestSendRequests(t *testing.T) { t.Run("request_succeeds", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + request := makeOutgoingRequest(t) response := &networkbridgemessages.ChunkFetchingResponse{} expectedValue := networkbridgemessages.NoSuchChunk{} require.NoError(t, response.SetValue(expectedValue)) - nbs := setUpNetworkBridgeSender(t, response, nil, nil) + nbs := setUpNetworkBridgeSender(t, ctrl, request, response, nil, nil) sendRequests := networkbridgemessages.SendRequests{ Requests: []*networkbridgemessages.OutgoingRequest{request}, @@ -46,10 +46,13 @@ func TestSendRequests(t *testing.T) { }) t.Run("request_fails", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + reqErr := errors.New("timeout") request := makeOutgoingRequest(t) - nbs := setUpNetworkBridgeSender(t, nil, nil, reqErr) + nbs := setUpNetworkBridgeSender(t, ctrl, request, nil, nil, reqErr) sendRequests := networkbridgemessages.SendRequests{ Requests: []*networkbridgemessages.OutgoingRequest{request}, @@ -66,10 +69,13 @@ func TestSendRequests(t *testing.T) { }) t.Run("decoding_fails", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + request := makeOutgoingRequest(t) rawResponse := []byte("an invalid network response message") - nbs := setUpNetworkBridgeSender(t, nil, rawResponse, nil) + nbs := setUpNetworkBridgeSender(t, ctrl, request, nil, rawResponse, nil) sendRequests := networkbridgemessages.SendRequests{ Requests: []*networkbridgemessages.OutgoingRequest{request}, @@ -101,24 +107,38 @@ func makeOutgoingRequest(t *testing.T) *networkbridgemessages.OutgoingRequest { // only one of response, rawResponse or reqErr should be non-nil func setUpNetworkBridgeSender( t *testing.T, + ctrl *gomock.Controller, + request *networkbridgemessages.OutgoingRequest, response network.ResponseMessage, rawResponse []byte, reqErr error, ) *NetworkBridgeSender { t.Helper() - if response != nil { - var err error - rawResponse, err = response.Encode() - require.NoError(t, err) - } + reqMaker := NewMockRequestMaker(ctrl) + reqMaker.EXPECT(). + Do(request.Recipient, request.Payload, gomock.AssignableToTypeOf(request.Payload.Response())). + DoAndReturn(func(to peer.ID, req network.Message, res network.ResponseMessage) error { + if reqErr != nil { + return reqErr + } + + if response != nil { + var err error + rawResponse, err = response.Encode() + require.NoError(t, err) + } + + if err := res.Decode(rawResponse); err != nil { + return err + } + return nil + }) - netService := &mockNetworkService{ - rrp: &mockRequestResponseProtocol{ - rawResponse: rawResponse, - err: reqErr, - }, - } + netService := NewMockNetwork(ctrl) + netService.EXPECT(). + GetRequestResponseProtocol(request.Payload.Protocol().String(), gomock.Any(), gomock.Any()). + Return(reqMaker) return RegisterSender(nil, netService) } @@ -130,62 +150,3 @@ func requireClosed(t *testing.T, ch chan networkbridgemessages.ReqRespResult) { t.Error("channel is not closed") } } - -// TODO use gomock -type mockRequestResponseProtocol struct { - rawResponse []byte - err error -} - -func (m *mockRequestResponseProtocol) Do(to peer.ID, req network.Message, response network.ResponseMessage) error { - if m.err != nil { - return m.err - } - - if err := response.Decode(m.rawResponse); err != nil { - return err - } - return nil -} - -// TODO use gomock -type mockNetworkService struct { - rrp *mockRequestResponseProtocol -} - -func (m *mockNetworkService) GossipMessage(msg network.NotificationsMessage) {} - -func (m *mockNetworkService) SendMessage(to peer.ID, msg network.NotificationsMessage) error { - return nil -} - -func (m *mockNetworkService) RegisterNotificationsProtocol(sub protocol.ID, - messageID network.MessageType, - handshakeGetter network.HandshakeGetter, - handshakeDecoder network.HandshakeDecoder, - handshakeValidator network.HandshakeValidator, - messageDecoder network.MessageDecoder, - messageHandler network.NotificationsMessageHandler, - batchHandler network.NotificationsMessageBatchHandler, - maxSize uint64, -) error { - return nil -} - -func (m *mockNetworkService) GetRequestResponseProtocol( - subprotocol string, - requestTimeout time.Duration, - maxResponseSize uint64, -) network.RequestMaker { - return m.rrp -} - -func (m *mockNetworkService) ReportPeer(change peerset.ReputationChange, p peer.ID) {} - -func (m *mockNetworkService) DisconnectPeer(setID int, p peer.ID) {} - -func (m *mockNetworkService) GetNetworkEventsChannel() chan *network.NetworkEventInfo { - return nil -} - -func (m *mockNetworkService) FreeNetworkEventsChannel(ch chan *network.NetworkEventInfo) {} From 15c72529458f4dd260f3c3b36546fa4d406c8b73 Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Thu, 23 Jan 2025 15:17:00 +0700 Subject: [PATCH 03/10] implement cancellation in OutgoingRequest --- .../messages/request_response_protocols.go | 24 +++++ dot/parachain/network-bridge/sender.go | 38 +++++--- dot/parachain/network-bridge/sender_test.go | 87 ++++++++++++++----- 3 files changed, 114 insertions(+), 35 deletions(-) diff --git a/dot/parachain/network-bridge/messages/request_response_protocols.go b/dot/parachain/network-bridge/messages/request_response_protocols.go index 0cf16b51b3..bb13444960 100644 --- a/dot/parachain/network-bridge/messages/request_response_protocols.go +++ b/dot/parachain/network-bridge/messages/request_response_protocols.go @@ -1,6 +1,8 @@ package messages import ( + "context" + "github.com/ChainSafe/gossamer/dot/network" "github.com/libp2p/go-libp2p/core/peer" ) @@ -55,14 +57,36 @@ type OutgoingRequest struct { Recipient peer.ID // TODO use a type that can contain either a peer ID or an authority ID Payload ReqProtocolMessage Result chan ReqRespResult + + ctx context.Context + cancel context.CancelFunc +} + +// Done returns a channel that is closed when the request is cancelled. +func (or *OutgoingRequest) Done() <-chan struct{} { + return or.ctx.Done() +} + +// Cancel cancels the request. +func (or *OutgoingRequest) Cancel() { + or.cancel() +} + +// IsCancelled returns true if the request has been cancelled. +func (or *OutgoingRequest) IsCancelled() bool { + return or.ctx.Err() != nil } +// NewOutgoingRequest creates a new outgoing request. func NewOutgoingRequest(recipient peer.ID, payload ReqProtocolMessage) *OutgoingRequest { result := make(chan ReqRespResult, 1) + ctx, cancel := context.WithCancel(context.Background()) return &OutgoingRequest{ Recipient: recipient, Payload: payload, Result: result, + ctx: ctx, + cancel: cancel, } } diff --git a/dot/parachain/network-bridge/sender.go b/dot/parachain/network-bridge/sender.go index f38928d280..9f15848dbf 100644 --- a/dot/parachain/network-bridge/sender.go +++ b/dot/parachain/network-bridge/sender.go @@ -120,20 +120,34 @@ func (nbs *NetworkBridgeSender) sendRequests( ifDisconnected networkbridgemessages.IfDisconnectedBehavior, //nolint:unparam ) { for _, request := range requests { - protoID := request.Payload.Protocol().String() - protocol := nbs.net.GetRequestResponseProtocol(protoID, requestTimeout, maxResponseSize) - response := request.Payload.Response() - result := networkbridgemessages.ReqRespResult{} - - // TODO This should probably be done on a goroutine. Unclear how to deal with cancellation/shutdown though. - err := protocol.Do(request.Recipient, request.Payload, response) - if err != nil { - result.Error = err - } else { - result.Response = response + if request.IsCancelled() { + close(request.Result) + continue } - request.Result <- result + result := nbs.sendRequest(request) + + if !request.IsCancelled() { + request.Result <- result + } close(request.Result) } } + +func (nbs *NetworkBridgeSender) sendRequest( + request *networkbridgemessages.OutgoingRequest, +) networkbridgemessages.ReqRespResult { + protoID := request.Payload.Protocol().String() + protocol := nbs.net.GetRequestResponseProtocol(protoID, requestTimeout, maxResponseSize) + response := request.Payload.Response() + result := networkbridgemessages.ReqRespResult{} + + err := protocol.Do(request.Recipient, request.Payload, response) + if err != nil { + result.Error = err + } else { + result.Response = response + } + + return result +} diff --git a/dot/parachain/network-bridge/sender_test.go b/dot/parachain/network-bridge/sender_test.go index 0dc8e4de08..d3aff64be9 100644 --- a/dot/parachain/network-bridge/sender_test.go +++ b/dot/parachain/network-bridge/sender_test.go @@ -2,6 +2,8 @@ package networkbridge import ( "errors" + "testing" + "github.com/ChainSafe/gossamer/dot/network" networkbridgemessages "github.com/ChainSafe/gossamer/dot/parachain/network-bridge/messages" parachaintypes "github.com/ChainSafe/gossamer/dot/parachain/types" @@ -9,7 +11,6 @@ import ( "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" - "testing" ) func TestSendRequests(t *testing.T) { @@ -90,6 +91,29 @@ func TestSendRequests(t *testing.T) { require.Nil(t, result.Response) requireClosed(t, request.Result) }) + + t.Run("cancel_request", func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + request := makeOutgoingRequest(t) + nbs := setUpNetworkBridgeSender(t, ctrl, request, nil, nil, nil) + + sendRequests := networkbridgemessages.SendRequests{ + Requests: []*networkbridgemessages.OutgoingRequest{request}, + IfDisconnected: networkbridgemessages.TryConnect, + } + + request.Cancel() + + err := nbs.processMessage(sendRequests) + require.NoError(t, err) + + result := <-request.Result + require.Nil(t, result.Response) + require.NoError(t, result.Error) + requireClosed(t, request.Result) + }) } // We arbitrarily use a ChunkFetchingRequest since it does not matter for testing SendRequests handling. @@ -104,7 +128,9 @@ func makeOutgoingRequest(t *testing.T) *networkbridgemessages.OutgoingRequest { }) } -// only one of response, rawResponse or reqErr should be non-nil +// Expect calls to Network.GetRequestResponseProtocol() and RequestMaker.Do() when only one of response, rawResponse or +// reqErr should be non-nil. +// Expect no calls Network.GetRequestResponseProtocol() RequestMaker.Do() when all three are nil. func setUpNetworkBridgeSender( t *testing.T, ctrl *gomock.Controller, @@ -115,30 +141,45 @@ func setUpNetworkBridgeSender( ) *NetworkBridgeSender { t.Helper() + expectCancellation := response == nil && rawResponse == nil && reqErr == nil reqMaker := NewMockRequestMaker(ctrl) - reqMaker.EXPECT(). - Do(request.Recipient, request.Payload, gomock.AssignableToTypeOf(request.Payload.Response())). - DoAndReturn(func(to peer.ID, req network.Message, res network.ResponseMessage) error { - if reqErr != nil { - return reqErr - } - - if response != nil { - var err error - rawResponse, err = response.Encode() - require.NoError(t, err) - } - - if err := res.Decode(rawResponse); err != nil { - return err - } - return nil - }) + + if expectCancellation { + reqMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), gomock.Any()). + Times(0) + } else { + reqMaker.EXPECT(). + Do(request.Recipient, request.Payload, gomock.AssignableToTypeOf(request.Payload.Response())). + DoAndReturn(func(to peer.ID, req network.Message, res network.ResponseMessage) error { + if reqErr != nil { + return reqErr + } + + if response != nil { + var err error + rawResponse, err = response.Encode() + require.NoError(t, err) + } + + if err := res.Decode(rawResponse); err != nil { + return err + } + return nil + }) + } netService := NewMockNetwork(ctrl) - netService.EXPECT(). - GetRequestResponseProtocol(request.Payload.Protocol().String(), gomock.Any(), gomock.Any()). - Return(reqMaker) + + if expectCancellation { + netService.EXPECT(). + GetRequestResponseProtocol(gomock.Any(), gomock.Any(), gomock.Any()). + Times(0) + } else { + netService.EXPECT(). + GetRequestResponseProtocol(request.Payload.Protocol().String(), gomock.Any(), gomock.Any()). + Return(reqMaker) + } return RegisterSender(nil, netService) } From 364de9343f97b3d8c1adaaba833e3ded16695e93 Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Thu, 23 Jan 2025 15:36:36 +0700 Subject: [PATCH 04/10] fix unrelated flaky tests --- .../prospective-parachains/fragment_chain_test.go | 9 +++++---- .../prospective_parachains_test.go | 10 +++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/dot/parachain/prospective-parachains/fragment_chain_test.go b/dot/parachain/prospective-parachains/fragment_chain_test.go index 7951565557..d08e44667a 100644 --- a/dot/parachain/prospective-parachains/fragment_chain_test.go +++ b/dot/parachain/prospective-parachains/fragment_chain_test.go @@ -916,7 +916,7 @@ func TestCandidateStorageMethods(t *testing.T) { possibleBackedCandidateHashes = append(possibleBackedCandidateHashes, entry.candidateHash) } - require.Equal(t, []parachaintypes.CandidateHash{candidateHash}, possibleBackedCandidateHashes) + require.Contains(t, possibleBackedCandidateHashes, candidateHash) // now mark it as backed storage.markBacked(candidateHash2) @@ -928,9 +928,10 @@ func TestCandidateStorageMethods(t *testing.T) { possibleBackedCandidateHashes = append(possibleBackedCandidateHashes, entry.candidateHash) } - require.Equal(t, []parachaintypes.CandidateHash{ - candidateHash, candidateHash2}, possibleBackedCandidateHashes) - + // The iterator returned by storage.possibleBackedParaChildren() takes values from a map. + // Therefore we must not assert on the order of elements in possibleBackedCandidateHashes. + require.Contains(t, possibleBackedCandidateHashes, candidateHash) + require.Contains(t, possibleBackedCandidateHashes, candidateHash2) }) }, }, diff --git a/dot/parachain/prospective-parachains/prospective_parachains_test.go b/dot/parachain/prospective-parachains/prospective_parachains_test.go index 806d6dfd17..89137d776e 100644 --- a/dot/parachain/prospective-parachains/prospective_parachains_test.go +++ b/dot/parachain/prospective-parachains/prospective_parachains_test.go @@ -361,10 +361,14 @@ func TestGetMinimumRelayParents(t *testing.T) { BlockNumber: 10, }, } - // Validate the results + result := <-sender - assert.Len(t, result, 2) - assert.Equal(t, expected, result) + assert.Len(t, result, len(expected)) + + // Validate the results without asserting on the order of ParaIDBlockNumber values. + for _, ex := range expected { + assert.Contains(t, result, ex) + } } // TestGetMinimumRelayParents_NoActiveLeaves ensures that getMinimumRelayParents From ba731e3eaf43badc35cf571bf721fbed086fc49f Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Thu, 23 Jan 2025 16:08:11 +0700 Subject: [PATCH 05/10] fix doc comment --- dot/parachain/network-bridge/sender_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dot/parachain/network-bridge/sender_test.go b/dot/parachain/network-bridge/sender_test.go index d3aff64be9..0f58a053cb 100644 --- a/dot/parachain/network-bridge/sender_test.go +++ b/dot/parachain/network-bridge/sender_test.go @@ -129,7 +129,7 @@ func makeOutgoingRequest(t *testing.T) *networkbridgemessages.OutgoingRequest { } // Expect calls to Network.GetRequestResponseProtocol() and RequestMaker.Do() when only one of response, rawResponse or -// reqErr should be non-nil. +// reqErr is non-nil. // Expect no calls Network.GetRequestResponseProtocol() RequestMaker.Do() when all three are nil. func setUpNetworkBridgeSender( t *testing.T, From 6923ef10679801e43eeb24c8ebcf1d13ccf14fd0 Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Thu, 23 Jan 2025 16:42:12 +0700 Subject: [PATCH 06/10] run SendRequests tests in parallel --- dot/parachain/network-bridge/sender_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dot/parachain/network-bridge/sender_test.go b/dot/parachain/network-bridge/sender_test.go index 0f58a053cb..e1575a5950 100644 --- a/dot/parachain/network-bridge/sender_test.go +++ b/dot/parachain/network-bridge/sender_test.go @@ -14,7 +14,10 @@ import ( ) func TestSendRequests(t *testing.T) { + t.Parallel() + t.Run("request_succeeds", func(t *testing.T) { + t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -47,6 +50,7 @@ func TestSendRequests(t *testing.T) { }) t.Run("request_fails", func(t *testing.T) { + t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -70,6 +74,7 @@ func TestSendRequests(t *testing.T) { }) t.Run("decoding_fails", func(t *testing.T) { + t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -93,6 +98,7 @@ func TestSendRequests(t *testing.T) { }) t.Run("cancel_request", func(t *testing.T) { + t.Parallel() ctrl := gomock.NewController(t) defer ctrl.Finish() From 9c326caa6cea4d12290c8ce32a3b123aa7a8a37f Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Thu, 23 Jan 2025 18:02:33 +0700 Subject: [PATCH 07/10] add/fix license headers --- .../network-bridge/messages/request_response_protocols.go | 3 +++ dot/parachain/network-bridge/mocks_generate_test.go | 2 +- dot/parachain/network-bridge/sender_test.go | 3 +++ 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/dot/parachain/network-bridge/messages/request_response_protocols.go b/dot/parachain/network-bridge/messages/request_response_protocols.go index bb13444960..820ec29fae 100644 --- a/dot/parachain/network-bridge/messages/request_response_protocols.go +++ b/dot/parachain/network-bridge/messages/request_response_protocols.go @@ -1,3 +1,6 @@ +// Copyright 2025 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + package messages import ( diff --git a/dot/parachain/network-bridge/mocks_generate_test.go b/dot/parachain/network-bridge/mocks_generate_test.go index a97b34b31b..9c1ac6f979 100644 --- a/dot/parachain/network-bridge/mocks_generate_test.go +++ b/dot/parachain/network-bridge/mocks_generate_test.go @@ -1,4 +1,4 @@ -// Copyright 2023 ChainSafe Systems (ON) +// Copyright 2025 ChainSafe Systems (ON) // SPDX-License-Identifier: LGPL-3.0-only package networkbridge diff --git a/dot/parachain/network-bridge/sender_test.go b/dot/parachain/network-bridge/sender_test.go index e1575a5950..d3beff5638 100644 --- a/dot/parachain/network-bridge/sender_test.go +++ b/dot/parachain/network-bridge/sender_test.go @@ -1,3 +1,6 @@ +// Copyright 2025 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + package networkbridge import ( From 6b7321ca69e882dfea73d60e3972ac26db139c2a Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Thu, 23 Jan 2025 22:38:43 +0700 Subject: [PATCH 08/10] ensure context inside OutgoingRequest is always cancelled --- dot/parachain/network-bridge/sender.go | 1 + 1 file changed, 1 insertion(+) diff --git a/dot/parachain/network-bridge/sender.go b/dot/parachain/network-bridge/sender.go index 9f15848dbf..66765a190b 100644 --- a/dot/parachain/network-bridge/sender.go +++ b/dot/parachain/network-bridge/sender.go @@ -129,6 +129,7 @@ func (nbs *NetworkBridgeSender) sendRequests( if !request.IsCancelled() { request.Result <- result + request.Cancel() // only called here to avoid resource leaks } close(request.Result) } From 9c6e11326dbf9a91c9a86bc0dd55b57a42df7560 Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Mon, 3 Feb 2025 15:30:03 +0700 Subject: [PATCH 09/10] increase request timeout --- dot/parachain/network-bridge/sender.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dot/parachain/network-bridge/sender.go b/dot/parachain/network-bridge/sender.go index 66765a190b..c702dc371e 100644 --- a/dot/parachain/network-bridge/sender.go +++ b/dot/parachain/network-bridge/sender.go @@ -108,7 +108,7 @@ func (nbs *NetworkBridgeSender) processMessage(msg any) error { return nil } -const requestTimeout = 200 * time.Millisecond // TODO is this reasonable? +const requestTimeout = 2 * time.Second // PoV is probably the largest message and is currently set at 5MB, but will likely be increased to 10MB in the future. // see: https://github.com/paritytech/polkadot-sdk/issues/5334 From 54efaf6803b03d0f5e72f08641362aedd978fe00 Mon Sep 17 00:00:00 2001 From: Haiko Schol Date: Mon, 3 Feb 2025 16:09:43 +0700 Subject: [PATCH 10/10] require response channel to be empty and closed --- dot/parachain/network-bridge/sender_test.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dot/parachain/network-bridge/sender_test.go b/dot/parachain/network-bridge/sender_test.go index d3beff5638..b5d98e0f45 100644 --- a/dot/parachain/network-bridge/sender_test.go +++ b/dot/parachain/network-bridge/sender_test.go @@ -49,7 +49,7 @@ func TestSendRequests(t *testing.T) { require.NoError(t, err) require.Equal(t, expectedValue, actualValue) - requireClosed(t, request.Result) + requireEmptyAndClosed(t, request.Result) }) t.Run("request_fails", func(t *testing.T) { @@ -73,7 +73,7 @@ func TestSendRequests(t *testing.T) { result := <-request.Result require.Equal(t, reqErr, result.Error) require.Nil(t, result.Response) - requireClosed(t, request.Result) + requireEmptyAndClosed(t, request.Result) }) t.Run("decoding_fails", func(t *testing.T) { @@ -97,7 +97,7 @@ func TestSendRequests(t *testing.T) { result := <-request.Result require.Error(t, result.Error) require.Nil(t, result.Response) - requireClosed(t, request.Result) + requireEmptyAndClosed(t, request.Result) }) t.Run("cancel_request", func(t *testing.T) { @@ -121,7 +121,7 @@ func TestSendRequests(t *testing.T) { result := <-request.Result require.Nil(t, result.Response) require.NoError(t, result.Error) - requireClosed(t, request.Result) + requireEmptyAndClosed(t, request.Result) }) } @@ -193,9 +193,10 @@ func setUpNetworkBridgeSender( return RegisterSender(nil, netService) } -func requireClosed(t *testing.T, ch chan networkbridgemessages.ReqRespResult) { +func requireEmptyAndClosed(t *testing.T, ch chan networkbridgemessages.ReqRespResult) { select { - case <-ch: + case _, ok := <-ch: + require.False(t, ok, "channel was not empty") default: t.Error("channel is not closed") }