Skip to content

Commit

Permalink
chore: fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen committed Feb 27, 2025
1 parent 33f0a5b commit 4607034
Show file tree
Hide file tree
Showing 12 changed files with 138 additions and 84 deletions.
2 changes: 2 additions & 0 deletions api/clients/retrieval_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type retrievalClient struct {
func NewRetrievalClient(
logger logging.Logger,
chainState core.ChainState,
socketStateCache core.SocketStateCache,
assignmentCoordinator core.AssignmentCoordinator,
nodeClient NodeClient,
verifier encoding.Verifier,
Expand All @@ -73,6 +74,7 @@ func NewRetrievalClient(
return &retrievalClient{
logger: logger.With("component", "RetrievalClient"),
chainState: chainState,
socketStateCache: socketStateCache,
assignmentCoordinator: assignmentCoordinator,
nodeClient: nodeClient,
verifier: verifier,
Expand Down
9 changes: 6 additions & 3 deletions api/clients/retrieval_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,20 @@ func setup(t *testing.T) {
t.Fatalf("failed to create new mocked chain data: %s", err)
}

socketStateCache, err := coremock.NewSocketStateCacheMock(nil)
if err != nil {
t.Fatalf("failed to create new mocked socket state cache: %s", err)
}

nodeClient = clientsmock.NewNodeClient()
coordinator = &core.StdAssignmentCoordinator{}
p, v, err := makeTestComponents()
if err != nil {
t.Fatal(err)
}
logger := testutils.GetLogger()
indexer = &indexermock.MockIndexer{}
indexer.On("Index").Return(nil).Once()

retrievalClient, err = clients.NewRetrievalClient(logger, chainState, coordinator, nodeClient, v, 2)
retrievalClient, err = clients.NewRetrievalClient(logger, chainState, socketStateCache, coordinator, nodeClient, v, 2)
if err != nil {
panic("failed to create a new retrieval client")
}
Expand Down
54 changes: 35 additions & 19 deletions api/clients/v2/retrieval_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"

"github.com/docker/go-units"

"github.com/Layr-Labs/eigenda/api/clients"
Expand Down Expand Up @@ -33,11 +34,12 @@ type RetrievalClient interface {
}

type retrievalClient struct {
logger logging.Logger
ethClient core.Reader
indexedChainState core.IndexedChainState
verifier encoding.Verifier
numConnections int
logger logging.Logger
ethClient core.Reader
chainState core.ChainState
socketStateCache core.SocketStateCache
verifier encoding.Verifier
numConnections int
}

var _ RetrievalClient = &retrievalClient{}
Expand All @@ -46,16 +48,18 @@ var _ RetrievalClient = &retrievalClient{}
func NewRetrievalClient(
logger logging.Logger,
ethClient core.Reader,
chainState core.IndexedChainState,
chainState core.ChainState,
socketStateCache core.SocketStateCache,
verifier encoding.Verifier,
numConnections int,
) RetrievalClient {
return &retrievalClient{
logger: logger.With("component", "RetrievalClient"),
ethClient: ethClient,
indexedChainState: chainState,
verifier: verifier,
numConnections: numConnections,
logger: logger.With("component", "RetrievalClient"),
ethClient: ethClient,
chainState: chainState,
socketStateCache: socketStateCache,
verifier: verifier,
numConnections: numConnections,
}
}

Expand All @@ -74,14 +78,23 @@ func (r *retrievalClient) GetBlob(
return nil, err
}

indexedOperatorState, err := r.indexedChainState.GetIndexedOperatorState(ctx, uint(referenceBlockNumber), []core.QuorumID{quorumID})
operatorState, err := r.chainState.GetOperatorState(ctx, uint(referenceBlockNumber), []core.QuorumID{quorumID})
if err != nil {
return nil, err
}
operators, ok := indexedOperatorState.Operators[quorumID]
operators, ok := operatorState.Operators[quorumID]
if !ok {
return nil, fmt.Errorf("no quorum with ID: %d", quorumID)
}
// grab all operators IDs
operatorIDs := make([]core.OperatorID, 0, len(operators))
for operatorID := range operators {
operatorIDs = append(operatorIDs, operatorID)
}
operatorSockets, err := r.socketStateCache.GetOperatorSockets(ctx, operatorIDs)
if err != nil {
return nil, err
}

blobVersions, err := r.ethClient.GetAllVersionedBlobParams(ctx)
if err != nil {
Expand All @@ -98,7 +111,7 @@ func (r *retrievalClient) GetBlob(
return nil, err
}

assignments, err := corev2.GetAssignments(indexedOperatorState.OperatorState, blobParam, quorumID)
assignments, err := corev2.GetAssignments(operatorState, blobParam, quorumID)
if err != nil {
return nil, errors.New("failed to get assignments")
}
Expand All @@ -107,10 +120,13 @@ func (r *retrievalClient) GetBlob(
chunksChan := make(chan clients.RetrievedChunks, len(operators))
pool := workerpool.New(r.numConnections)
for opID := range operators {
opID := opID
opInfo := indexedOperatorState.IndexedOperators[opID]
socket := operatorSockets[opID]
if socket == nil {
r.logger.Warn("no socket for operator", "operator", opID)
continue
}
pool.Submit(func() {
r.getChunksFromOperator(ctx, opID, opInfo, blobKey, quorumID, chunksChan)
r.getChunksFromOperator(ctx, opID, socket, blobKey, quorumID, chunksChan)
})
}

Expand Down Expand Up @@ -160,7 +176,7 @@ func (r *retrievalClient) GetBlob(
func (r *retrievalClient) getChunksFromOperator(
ctx context.Context,
opID core.OperatorID,
opInfo *core.IndexedOperatorInfo,
socket *core.OperatorSocket,
blobKey corev2.BlobKey,
quorumID core.QuorumID,
chunksChan chan clients.RetrievedChunks,
Expand All @@ -176,7 +192,7 @@ func (r *retrievalClient) getChunksFromOperator(
maxMessageSize := maxBlobSize*encodingRate + fudgeFactor

conn, err := grpc.NewClient(
core.OperatorSocket(opInfo.Socket).GetV2RetrievalSocket(),
socket.GetV2RetrievalSocket(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMessageSize)),
)
Expand Down
8 changes: 6 additions & 2 deletions api/clients/v2/validator_payload_retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ func BuildValidatorPayloadRetriever(
}

chainState := eth.NewChainState(reader, ethClient)
indexedChainState := thegraph.MakeIndexedChainState(thegraphConfig, chainState, logger)
socketStateCache, err := eth.NewSocketStateCache(context.Background(), reader, logger)
if err != nil {
return nil, fmt.Errorf("new socket state cache: %w", err)
}

kzgVerifier, err := verifier.NewVerifier(&kzgConfig, nil)
if err != nil {
Expand All @@ -69,7 +72,8 @@ func BuildValidatorPayloadRetriever(
retrievalClient := NewRetrievalClient(
logger,
reader,
indexedChainState,
chainState,
socketStateCache,
kzgVerifier,
int(validatorPayloadRetrieverConfig.MaxConnectionCount))

Expand Down
57 changes: 57 additions & 0 deletions core/mock/socket_state_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package mock

import (
"context"
"fmt"

"github.com/Layr-Labs/eigenda/core"
"github.com/stretchr/testify/mock"
)

type SocketStateCacheMock struct {
mock.Mock

OperatorSockets core.OperatorSockets
}

var _ core.SocketStateCache = (*SocketStateCacheMock)(nil)

func NewSocketStateCacheMock(operatorSockets core.OperatorSockets) (*SocketStateCacheMock, error) {
if operatorSockets == nil {
operatorSockets = make(map[core.OperatorID]*core.OperatorSocket)
}
return &SocketStateCacheMock{
OperatorSockets: operatorSockets,
}, nil
}

func (s *SocketStateCacheMock) GetOperatorSocket(ctx context.Context, operator core.OperatorID) (string, error) {
args := s.Called(ctx, operator)
return args.Get(0).(string), args.Error(1)
}

func (s *SocketStateCacheMock) GetOperatorSockets(ctx context.Context, operators []core.OperatorID) (core.OperatorSockets, error) {
args := s.Called(ctx, operators)
if args.Get(0) != nil {
return args.Get(0).(core.OperatorSockets), args.Error(1)
}

// If no mock expectation is set, generate deterministic sockets for each operator
sockets := make(map[core.OperatorID]*core.OperatorSocket)
for i, operator := range operators {
socket := generateSocketFromOperatorID(i, operator)
sockets[operator] = &socket
}
return sockets, nil
}

// generateSocketFromOperatorID creates a deterministic socket based on the operator ID
func generateSocketFromOperatorID(operatorIndex int, id core.OperatorID) core.OperatorSocket {
host := "0.0.0.0"
dispersalPort := fmt.Sprintf("3%03v", 2*operatorIndex)
retrievalPort := fmt.Sprintf("3%03v", 2*operatorIndex+1)
v2DispersalPort := fmt.Sprintf("3%03v", 2*operatorIndex+2)
v2RetrievalPort := fmt.Sprintf("3%03v", 2*operatorIndex+3)

return core.MakeOperatorSocket(host, dispersalPort, retrievalPort, v2DispersalPort, v2RetrievalPort)
}
19 changes: 4 additions & 15 deletions inabox/tests/integration_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
rollupbindings "github.com/Layr-Labs/eigenda/contracts/bindings/MockRollup"
"github.com/Layr-Labs/eigenda/core"
"github.com/Layr-Labs/eigenda/core/eth"
"github.com/Layr-Labs/eigenda/core/thegraph"
corev2 "github.com/Layr-Labs/eigenda/core/v2"
"github.com/Layr-Labs/eigenda/encoding/kzg"
"github.com/Layr-Labs/eigenda/encoding/kzg/verifier"
Expand Down Expand Up @@ -202,31 +201,21 @@ func setupRetrievalClient(testConfig *deploy.Config) error {
return err
}

graphBackoff, err := time.ParseDuration(testConfig.Retriever.RETRIEVER_GRAPH_BACKOFF)
socketStateCache, err := eth.NewSocketStateCache(context.Background(), tx, logger)
if err != nil {
return err
}
maxRetries, err := strconv.Atoi(testConfig.Retriever.RETRIEVER_GRAPH_MAX_RETRIES)
if err != nil {
return err
}
ics := thegraph.MakeIndexedChainState(thegraph.Config{
Endpoint: testConfig.Retriever.RETRIEVER_GRAPH_URL,
PullInterval: graphBackoff,
MaxRetries: maxRetries,
}, cs, logger)

retrievalClient, err = clients.NewRetrievalClient(logger, ics, agn, nodeClient, v, 10)
retrievalClient, err = clients.NewRetrievalClient(logger, cs, socketStateCache, agn, nodeClient, v, 10)
if err != nil {
return err
}
chainReader, err := eth.NewReader(logger, ethClient, testConfig.Retriever.RETRIEVER_BLS_OPERATOR_STATE_RETRIVER, testConfig.Retriever.RETRIEVER_EIGENDA_SERVICE_MANAGER)
if err != nil {
return err
}
retrievalClientV2 = clientsv2.NewRetrievalClient(logger, chainReader, ics, v, 10)
retrievalClientV2 = clientsv2.NewRetrievalClient(logger, chainReader, cs, socketStateCache, v, 10)

return ics.Start(context.Background())
return nil
}

var _ = AfterSuite(func() {
Expand Down
16 changes: 8 additions & 8 deletions retriever/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"github.com/Layr-Labs/eigenda/common/healthcheck"
"github.com/Layr-Labs/eigenda/core"
"github.com/Layr-Labs/eigenda/core/eth"
"github.com/Layr-Labs/eigenda/core/thegraph"
"github.com/Layr-Labs/eigenda/encoding/kzg/verifier"
"github.com/Layr-Labs/eigenda/retriever"
retrivereth "github.com/Layr-Labs/eigenda/retriever/eth"
Expand Down Expand Up @@ -99,19 +98,20 @@ func RetrieverMain(ctx *cli.Context) error {
if err != nil {
log.Fatalln("could not start tcp listener", err)
}

logger.Info("Connecting to subgraph", "url", config.ChainStateConfig.Endpoint)
ics := thegraph.MakeIndexedChainState(config.ChainStateConfig, cs, logger)
socketStateCache, err := eth.NewSocketStateCache(context.Background(), tx, logger)
if err != nil {
log.Fatalln("could not start tcp listener", err)
}

if config.EigenDAVersion == 1 {
agn := &core.StdAssignmentCoordinator{}
retrievalClient, err := clients.NewRetrievalClient(logger, ics, agn, nodeClient, v, config.NumConnections)
retrievalClient, err := clients.NewRetrievalClient(logger, cs, socketStateCache, agn, nodeClient, v, config.NumConnections)
if err != nil {
log.Fatalln("could not start tcp listener", err)
}

chainClient := retrivereth.NewChainClient(gethClient, logger)
retrieverServiceServer := retriever.NewServer(config, logger, retrievalClient, ics, chainClient)
retrieverServiceServer := retriever.NewServer(config, logger, retrievalClient, chainClient)
if err = retrieverServiceServer.Start(context.Background()); err != nil {
log.Fatalln("failed to start retriever service server", err)
}
Expand All @@ -131,8 +131,8 @@ func RetrieverMain(ctx *cli.Context) error {
}

if config.EigenDAVersion == 2 {
retrievalClient := clientsv2.NewRetrievalClient(logger, tx, ics, v, config.NumConnections)
retrieverServiceServer := retrieverv2.NewServer(config, logger, retrievalClient, ics)
retrievalClient := clientsv2.NewRetrievalClient(logger, tx, cs, socketStateCache, v, config.NumConnections)
retrieverServiceServer := retrieverv2.NewServer(config, logger, retrievalClient, cs)
if err = retrieverServiceServer.Start(context.Background()); err != nil {
log.Fatalln("failed to start retriever service server", err)
}
Expand Down
5 changes: 1 addition & 4 deletions retriever/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ type Server struct {
config *Config
retrievalClient clients.RetrievalClient
chainClient eth.ChainClient
indexedState core.IndexedChainState
logger logging.Logger
metrics *Metrics
}
Expand All @@ -28,7 +27,6 @@ func NewServer(
config *Config,
logger logging.Logger,
retrievalClient clients.RetrievalClient,
indexedState core.IndexedChainState,
chainClient eth.ChainClient,
) *Server {
metrics := NewMetrics(config.MetricsConfig.HTTPPort, logger)
Expand All @@ -37,15 +35,14 @@ func NewServer(
config: config,
retrievalClient: retrievalClient,
chainClient: chainClient,
indexedState: indexedState,
logger: logger.With("component", "RetrieverServer"),
metrics: metrics,
}
}

func (s *Server) Start(ctx context.Context) error {
s.metrics.Start(ctx)
return s.indexedState.Start(ctx)
return nil
}

func (s *Server) RetrieveBlob(ctx context.Context, req *pb.BlobRequest) (*pb.BlobReply, error) {
Expand Down
Loading

0 comments on commit 4607034

Please sign in to comment.