Skip to content

Commit

Permalink
Structify Opeartor Socket
Browse files Browse the repository at this point in the history
Changes:
	Defines a concrete type OperatorSocket
	Internal state is unexported to prevent unintented changes from the importing package.
	Defines Get functions to access methods
  • Loading branch information
supriya-premkumar committed Feb 25, 2025
1 parent 3398685 commit 4d165f9
Show file tree
Hide file tree
Showing 19 changed files with 189 additions and 112 deletions.
17 changes: 15 additions & 2 deletions api/clients/node_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ func (c client) GetBlobHeader(
batchHeaderHash [32]byte,
blobIndex uint32,
) (*core.BlobHeader, *merkletree.Proof, error) {
operatorSocket, err := core.ParseOperatorSocket(socket)
if err != nil {
return nil, nil, err
}
conn, err := grpc.NewClient(
core.OperatorSocket(socket).GetV1RetrievalSocket(),
operatorSocket.GetV1RetrievalSocket(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
Expand Down Expand Up @@ -85,8 +89,17 @@ func (c client) GetChunks(
quorumID core.QuorumID,
chunksChan chan RetrievedChunks,
) {
operatorSocket, err := core.ParseOperatorSocket(opInfo.Socket)
if err != nil {
chunksChan <- RetrievedChunks{
OperatorID: opID,
Err: err,
Chunks: nil,
}
return
}
conn, err := grpc.NewClient(
core.OperatorSocket(opInfo.Socket).GetV1RetrievalSocket(),
operatorSocket.GetV1RetrievalSocket(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
Expand Down
11 changes: 10 additions & 1 deletion api/clients/v2/retrieval_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,17 @@ func (r *retrievalClient) getChunksFromOperator(
fudgeFactor := units.MiB // to allow for some overhead from things like protobuf encoding
maxMessageSize := maxBlobSize*encodingRate + fudgeFactor

operatorSocket, err := core.ParseOperatorSocket(opInfo.Socket)
if err != nil {
chunksChan <- clients.RetrievedChunks{
OperatorID: opID,
Err: err,
Chunks: nil,
}
return
}
conn, err := grpc.NewClient(
core.OperatorSocket(opInfo.Socket).GetV2RetrievalSocket(),
operatorSocket.GetV2RetrievalSocket(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMessageSize)),
)
Expand Down
4 changes: 2 additions & 2 deletions core/mock/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ func (d *ChainDataMock) GetTotalOperatorStateWithQuorums(ctx context.Context, bl
retrievalPort := fmt.Sprintf("3%03v", 2*i+1)
v2DispersalPort := fmt.Sprintf("3%03v", 2*i+2)
v2RetrievalPort := fmt.Sprintf("3%03v", 2*i+3)
socket := core.MakeOperatorSocket(host, dispersalPort, retrievalPort, v2DispersalPort, v2RetrievalPort)
socket := core.NewOperatorSocket(host, dispersalPort, retrievalPort, v2DispersalPort, v2RetrievalPort)

indexed := &core.IndexedOperatorInfo{
Socket: string(socket),
Socket: socket.String(),
PubkeyG1: d.KeyPairs[id].GetPubKeyG1(),
PubkeyG2: d.KeyPairs[id].GetPubKeyG2(),
}
Expand Down
24 changes: 12 additions & 12 deletions core/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,33 +528,33 @@ func decode(data []byte, obj any) error {
}

func (s OperatorSocket) GetV1DispersalSocket() string {
ip, v1DispersalPort, _, _, _, err := ParseOperatorSocket(string(s))
if err != nil {
if s.host == "" || s.v1DispersalPort == "" {
return ""
}
return fmt.Sprintf("%s:%s", ip, v1DispersalPort)
return fmt.Sprintf("%s:%s", s.host, s.v1DispersalPort)
}

func (s OperatorSocket) GetV2DispersalSocket() string {
ip, _, _, v2DispersalPort, _, err := ParseOperatorSocket(string(s))
if err != nil || v2DispersalPort == "" {
if s.host == "" || s.v2DispersalPort == "" {
return ""
}
return fmt.Sprintf("%s:%s", ip, v2DispersalPort)
return fmt.Sprintf("%s:%s", s.host, s.v2DispersalPort)
}

func (s OperatorSocket) GetV1RetrievalSocket() string {
ip, _, v1retrievalPort, _, _, err := ParseOperatorSocket(string(s))
if err != nil {
if s.host == "" || s.v1RetrievalPort == "" {
return ""
}
return fmt.Sprintf("%s:%s", ip, v1retrievalPort)
return fmt.Sprintf("%s:%s", s.host, s.v1RetrievalPort)
}

func (s OperatorSocket) GetV2RetrievalSocket() string {
ip, _, _, _, v2RetrievalPort, err := ParseOperatorSocket(string(s))
if err != nil || v2RetrievalPort == "" {
if s.host == "" || s.v2RetrievalPort == "" {
return ""
}
return fmt.Sprintf("%s:%s", ip, v2RetrievalPort)
return fmt.Sprintf("%s:%s", s.host, s.v2RetrievalPort)
}

func (s OperatorSocket) GetHost() string {
return s.host
}
96 changes: 60 additions & 36 deletions core/serialization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,123 +195,147 @@ func TestHashPubKeyG1(t *testing.T) {
}

func TestParseOperatorSocket(t *testing.T) {
operatorSocket := "localhost:1234;5678;9999;10001"
host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, v2RetrievalPort, err := core.ParseOperatorSocket(operatorSocket)
opSocketStr := "localhost:1234;5678;9999;10001"
operatorSocket, err := core.ParseOperatorSocket(opSocketStr)
assert.NoError(t, err)
assert.Equal(t, "localhost", host)
assert.Equal(t, "1234", v1DispersalPort)
assert.Equal(t, "5678", v1RetrievalPort)
assert.Equal(t, "9999", v2DispersalPort)
assert.Equal(t, "10001", v2RetrievalPort)

host, v1DispersalPort, v1RetrievalPort, v2DispersalPort, _, err = core.ParseOperatorSocket("localhost:1234;5678")
assert.Equal(t, "localhost", operatorSocket.GetHost())
assert.Equal(t, "localhost:1234", operatorSocket.GetV1DispersalSocket())
assert.Equal(t, "localhost:5678", operatorSocket.GetV1RetrievalSocket())
assert.Equal(t, "localhost:9999", operatorSocket.GetV2DispersalSocket())
assert.Equal(t, "localhost:10001", operatorSocket.GetV2RetrievalSocket())

opSocketStr = "localhost:1234;5678"
operatorSocket, err = core.ParseOperatorSocket(opSocketStr)
assert.NoError(t, err)
assert.Equal(t, "localhost", host)
assert.Equal(t, "1234", v1DispersalPort)
assert.Equal(t, "5678", v1RetrievalPort)
assert.Equal(t, "", v2DispersalPort)
assert.Equal(t, "localhost", operatorSocket.GetHost())
assert.Equal(t, "localhost:1234", operatorSocket.GetV1DispersalSocket())
assert.Equal(t, "localhost:5678", operatorSocket.GetV1RetrievalSocket())
assert.Equal(t, "", operatorSocket.GetV2DispersalSocket())

_, _, _, _, _, err = core.ParseOperatorSocket("localhost;1234;5678")
opSocketStr = "localhost;1234;5678"
_, err = core.ParseOperatorSocket(opSocketStr)
assert.NotNil(t, err)
assert.ErrorContains(t, err, "invalid host address format")

_, _, _, _, _, err = core.ParseOperatorSocket("localhost:12345678")
opSocketStr = "localhost:12345678"
_, err = core.ParseOperatorSocket(opSocketStr)
assert.NotNil(t, err)
assert.ErrorContains(t, err, "invalid v1 dispersal port format")

_, _, _, _, _, err = core.ParseOperatorSocket("localhost1234;5678")
opSocketStr = "localhost1234;5678"
_, err = core.ParseOperatorSocket(opSocketStr)
assert.NotNil(t, err)
assert.ErrorContains(t, err, "invalid host address format")
}

func TestGetV1DispersalSocket(t *testing.T) {
operatorSocket := core.OperatorSocket("localhost:1234;5678;9999;1025")
operatorSocket, err := core.ParseOperatorSocket("localhost:1234;5678;9999;1025")
socket := operatorSocket.GetV1DispersalSocket()
assert.NoError(t, err)
assert.Equal(t, "localhost:1234", socket)

operatorSocket = core.OperatorSocket("localhost:1234;5678")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234;5678")
socket = operatorSocket.GetV1DispersalSocket()
assert.NoError(t, err)
assert.Equal(t, "localhost:1234", socket)

operatorSocket = core.OperatorSocket("localhost:1234;5678;")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234;5678;")
socket = operatorSocket.GetV1DispersalSocket()
assert.NotNil(t, err)
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234")
assert.NotNil(t, err)
socket = operatorSocket.GetV1DispersalSocket()
assert.Equal(t, "", socket)
}

func TestGetV1RetrievalSocket(t *testing.T) {
// Valid v1/v2 socket
operatorSocket := core.OperatorSocket("localhost:1234;5678;9999;10001")
operatorSocket, err := core.ParseOperatorSocket("localhost:1234;5678;9999;10001")
assert.NoError(t, err)
socket := operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "localhost:5678", socket)

// Valid v1 socket
operatorSocket = core.OperatorSocket("localhost:1234;5678")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234;5678")
socket = operatorSocket.GetV1RetrievalSocket()
assert.NoError(t, err)
assert.Equal(t, "localhost:5678", socket)

// Invalid socket testcases
operatorSocket = core.OperatorSocket("localhost:1234;5678;9999;10001;")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234;5678;9999;10001;")
assert.NotNil(t, err)
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234;5678;")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234;5678;")
assert.NotNil(t, err)
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:;1234;5678;")
operatorSocket, err = core.ParseOperatorSocket("localhost:;1234;5678;")
assert.NotNil(t, err)
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234;:;5678;")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234;:;5678;")
assert.NotNil(t, err)
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:;;;")
operatorSocket, err = core.ParseOperatorSocket("localhost:;;;")
assert.NotNil(t, err)
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234")
assert.NotNil(t, err)
socket = operatorSocket.GetV1RetrievalSocket()
assert.Equal(t, "", socket)
}

func TestGetV2RetrievalSocket(t *testing.T) {
// Valid v1/v2 socket
operatorSocket := core.OperatorSocket("localhost:1234;5678;9999;10001")
operatorSocket, err := core.ParseOperatorSocket("localhost:1234;5678;9999;10001")
assert.NoError(t, err)
socket := operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "localhost:10001", socket)

// Invalid v2 socket
operatorSocket = core.OperatorSocket("localhost:1234;5678")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234;5678")
assert.NoError(t, err)
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

// Invalid socket testcases
operatorSocket = core.OperatorSocket("localhost:1234;5678;9999;10001;")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234;5678;9999;10001;")
assert.NotNil(t, err)
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234;5678;")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234;5678;")
assert.NotNil(t, err)
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:;1234;5678;")
operatorSocket, err = core.ParseOperatorSocket("localhost:;1234;5678;")
assert.NotNil(t, err)
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234;:;5678;")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234;:;5678;")
assert.NotNil(t, err)
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:;;;")
operatorSocket, err = core.ParseOperatorSocket("localhost:;;;")
assert.NotNil(t, err)
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)

operatorSocket = core.OperatorSocket("localhost:1234")
operatorSocket, err = core.ParseOperatorSocket("localhost:1234")
assert.NotNil(t, err)
socket = operatorSocket.GetV2RetrievalSocket()
assert.Equal(t, "", socket)
}
Expand Down
Loading

0 comments on commit 4d165f9

Please sign in to comment.