Skip to content

Commit

Permalink
feat(go/adbc/driver/flightsql): enable incremental queries
Browse files Browse the repository at this point in the history
Closes apache#1451.
  • Loading branch information
lidavidm committed Jan 12, 2024
1 parent 6b73e52 commit 0351559
Show file tree
Hide file tree
Showing 36 changed files with 549 additions and 116 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ $ cd go/adbc && go-licenses report ./... \
--ignore github.com/apache/arrow/go/v11 \
--ignore github.com/apache/arrow/go/v12 \
--ignore github.com/apache/arrow/go/v13 \
--ignore github.com/apache/arrow/go/v14 \
--ignore github.com/apache/arrow/go/v15 \
--template ../../license.tpl > ../../LICENSE.txt 2> /dev/null
```

Expand Down
4 changes: 2 additions & 2 deletions go/adbc/adbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ import (
"context"
"fmt"

"github.com/apache/arrow/go/v14/arrow"
"github.com/apache/arrow/go/v14/arrow/array"
"github.com/apache/arrow/go/v15/arrow"
"github.com/apache/arrow/go/v15/arrow/array"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
)
Expand Down
2 changes: 1 addition & 1 deletion go/adbc/driver/driverbase/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"context"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v14/arrow/memory"
"github.com/apache/arrow/go/v15/arrow/memory"
"golang.org/x/exp/slog"
)

Expand Down
2 changes: 1 addition & 1 deletion go/adbc/driver/driverbase/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ package driverbase

import (
"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v14/arrow/memory"
"github.com/apache/arrow/go/v15/arrow/memory"
)

// DriverImpl is an interface that drivers implement to provide
Expand Down
10 changes: 5 additions & 5 deletions go/adbc/driver/flightsql/cmd/testserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ import (
"strconv"
"strings"

"github.com/apache/arrow/go/v14/arrow"
"github.com/apache/arrow/go/v14/arrow/array"
"github.com/apache/arrow/go/v14/arrow/flight"
"github.com/apache/arrow/go/v14/arrow/flight/flightsql"
"github.com/apache/arrow/go/v14/arrow/memory"
"github.com/apache/arrow/go/v15/arrow"
"github.com/apache/arrow/go/v15/arrow/array"
"github.com/apache/arrow/go/v15/arrow/flight"
"github.com/apache/arrow/go/v15/arrow/flight/flightsql"
"github.com/apache/arrow/go/v15/arrow/memory"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
Expand Down
302 changes: 296 additions & 6 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,20 @@ import (
"net/textproto"
"os"
"strings"
"strconv"
"testing"
"time"
"sync"
"github.com/google/uuid"

"github.com/apache/arrow-adbc/go/adbc"
driver "github.com/apache/arrow-adbc/go/adbc/driver/flightsql"
"github.com/apache/arrow/go/v14/arrow"
"github.com/apache/arrow/go/v14/arrow/array"
"github.com/apache/arrow/go/v14/arrow/flight"
"github.com/apache/arrow/go/v14/arrow/flight/flightsql"
"github.com/apache/arrow/go/v14/arrow/flight/flightsql/schema_ref"
"github.com/apache/arrow/go/v14/arrow/memory"
"github.com/apache/arrow/go/v15/arrow"
"github.com/apache/arrow/go/v15/arrow/array"
"github.com/apache/arrow/go/v15/arrow/flight"
"github.com/apache/arrow/go/v15/arrow/flight/flightsql"
"github.com/apache/arrow/go/v15/arrow/flight/flightsql/schema_ref"
"github.com/apache/arrow/go/v15/arrow/memory"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/stretchr/testify/suite"
"golang.org/x/exp/maps"
Expand Down Expand Up @@ -108,6 +111,10 @@ func TestExecuteSchema(t *testing.T) {
suite.Run(t, &ExecuteSchemaTests{})
}

func TestIncrementalPoll(t *testing.T) {
suite.Run(t, &IncrementalPollTests{})
}

func TestTimeout(t *testing.T) {
suite.Run(t, &TimeoutTests{})
}
Expand Down Expand Up @@ -427,6 +434,289 @@ func (ts *ExecuteSchemaTests) TestQuery() {
ts.True(expectedSchema.Equal(schema), schema.String())
}

// ---- IncrementalPoll Tests --------------------

type IncrementalQuery struct {
query string
nextIndex int
}

type IncrementalPollTestServer struct {
flightsql.BaseServer
mu sync.Mutex
queries map[string]*IncrementalQuery
testCases map[string]IncrementalPollTestCase
}

func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
srv.mu.Lock()
defer srv.mu.Unlock()

var val wrapperspb.StringValue
var err error
if err = proto.Unmarshal(desc.Cmd, &val); err != nil {
return nil, err
}
queryId := val.Value
progress := int64(0)
if strings.Contains(queryId, ";") {
parts := strings.SplitN(queryId, ";", 2)
queryId = parts[0]
progress, err = strconv.ParseInt(parts[1], 10, 32)
if err != nil {
return nil, err
}
}

query, ok := srv.queries[queryId]
if !ok {
return nil, status.Errorf(codes.NotFound, "Query ID not found")
}

testCase, ok := srv.testCases[query.query]
if !ok {
return nil, status.Errorf(codes.Unimplemented, fmt.Sprintf("Invalid case %s", query.query))
}

if testCase.differentRetryDescriptor && progress != int64(query.nextIndex) {
return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("Used wrong retry descriptor, expected %d but got %d", query.nextIndex, progress))
}

return srv.MakePollInfo(&testCase, query, queryId)
}

func (srv *IncrementalPollTestServer) PollFlightInfoStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
queryId := uuid.New().String()

testCase, ok := srv.testCases[query.GetQuery()]
if !ok {
return nil, status.Errorf(codes.Unimplemented, fmt.Sprintf("Invalid case %s", query.GetQuery()))
}

srv.mu.Lock()
defer srv.mu.Unlock()

srv.queries[queryId] = &IncrementalQuery{
query: query.GetQuery(),
nextIndex: 0,
}

return srv.MakePollInfo(&testCase, srv.queries[queryId], queryId)
}

func (srv *IncrementalPollTestServer) MakePollInfo(testCase *IncrementalPollTestCase, query *IncrementalQuery, queryId string) (*flight.PollInfo, error) {
schema := flight.SerializeSchema(arrow.NewSchema([]arrow.Field{
{Name: "ints", Type: arrow.PrimitiveTypes.Int32},
}, nil), srv.Alloc)

pb := wrapperspb.StringValue{Value: queryId}
if testCase.differentRetryDescriptor {
pb.Value = queryId + ";" + strconv.Itoa(query.nextIndex + 1)
}
descriptor, err := proto.Marshal(&pb)
if err != nil {
return nil, err
}

numEndpoints := 0
for i := 0; i <= query.nextIndex; i++ {
if i >= len(testCase.progress) {
break
}
numEndpoints += testCase.progress[i]
}
endpoints := make([]*flight.FlightEndpoint, numEndpoints)
for i := range endpoints {
endpoints[i] = &flight.FlightEndpoint{
Ticket: &flight.Ticket{
Ticket: []byte{},
},
}
}

query.nextIndex++
pollInfo := flight.PollInfo{
Info: &flight.FlightInfo {
Schema: schema,
Endpoint: endpoints,
},
FlightDescriptor: &flight.FlightDescriptor{
Type: flight.DescriptorCMD,
Cmd: descriptor,
},
Progress: proto.Float64(float64(query.nextIndex) / float64(len(testCase.progress))),
}

if query.nextIndex >= len(testCase.progress) {
if testCase.completeLazily {
if query.nextIndex == len(testCase.progress) {
// Make the client poll one more time
} else {
pollInfo.FlightDescriptor = nil
delete(srv.queries, queryId)
}

} else {
pollInfo.FlightDescriptor = nil
delete(srv.queries, queryId)
}
}

fmt.Printf("Returning %d endpoints, has retry? %t\n", numEndpoints, pollInfo.FlightDescriptor != nil)

return &pollInfo, nil
}

type IncrementalPollTestCase struct {
// on each poll (including the first), this many new endpoints complete
// making 0 progress is allowed, but not recommended (allow clients to 'long poll')
progress []int

// use a different retry descriptor for each poll
differentRetryDescriptor bool

// require one extra poll to get completion (i.e. the last poll will have a nil FlightInfo)
completeLazily bool
}

type IncrementalPollTests struct {
ServerBasedTests
testCases map[string]IncrementalPollTestCase
}

func (suite *IncrementalPollTests) SetupSuite() {
suite.testCases = map[string]IncrementalPollTestCase{
"basic": {
progress: []int{1, 1, 1, 1},
},
"basic 2": {
progress: []int{2, 3, 4, 5},
},
"basic 3": {
progress: []int{2},
},
"descriptor changes": {
progress: []int{1, 1, 1, 1},
differentRetryDescriptor: true,
},
"lazy": {
progress: []int{1, 1, 1, 1},
completeLazily: true,
},
"lazy 2": {
progress: []int{1, 1, 1, 0},
completeLazily: true,
},
"no progress": {
progress: []int{0, 1, 1, 1},
},
"no progress 2": {
progress: []int{0, 0, 1, 1},
},
"no progress 3": {
progress: []int{0, 0, 1, 0},
},
}

srv := IncrementalPollTestServer{
queries: make(map[string]*IncrementalQuery),
testCases: suite.testCases,
}
srv.Alloc = memory.DefaultAllocator
suite.DoSetupSuite(&srv, nil, nil)
}

func (ts *IncrementalPollTests) TestMaxProgress() {
stmt, err := ts.cnxn.NewStatement()
ts.NoError(err)
defer stmt.Close()
opts := stmt.(adbc.GetSetOptions)

val, err := opts.GetOptionDouble(adbc.OptionKeyMaxProgress)
ts.NoError(err)
ts.Equal(1.0, val)
}

func (ts *IncrementalPollTests) TestOptionValue() {
stmt, err := ts.cnxn.NewStatement()
ts.NoError(err)
defer stmt.Close()
opts := stmt.(adbc.GetSetOptions)

val, err := opts.GetOption(adbc.OptionKeyIncremental)
ts.NoError(err)
ts.Equal(adbc.OptionValueDisabled, val)

ts.NoError(stmt.SetOption(adbc.OptionKeyIncremental, adbc.OptionValueEnabled))

val, err = opts.GetOption(adbc.OptionKeyIncremental)
ts.NoError(err)
ts.Equal(adbc.OptionValueEnabled, val)

var adbcErr adbc.Error
ts.ErrorAs(stmt.SetOption(adbc.OptionKeyIncremental, "foobar"), &adbcErr)
ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code)
}

func (ts *IncrementalPollTests) RunOneTestCase(ctx context.Context, stmt adbc.Statement, name string, testCase *IncrementalPollTestCase) {
opts := stmt.(adbc.GetSetOptions)
ts.NoError(stmt.SetSqlQuery(name))

for idx, progress := range testCase.progress {
fmt.Printf("Poll %d/%d\n", idx + 1, len(testCase.progress))
if progress == 0 {
// the driver hides this from us
continue
}

_, partitions, _, err := stmt.ExecutePartitions(ctx)
ts.NoError(err)

ts.Equal(uint64(progress), partitions.NumPartitions)

val, err := opts.GetOptionDouble(adbc.OptionKeyProgress)
ts.NoError(err)
ts.Equal(float64(idx + 1) / float64(len(testCase.progress)), val)
}

fmt.Println("Poll last")
// Query completed, but we find out by getting no partitions in this call
_, partitions, _, err := stmt.ExecutePartitions(ctx)
ts.NoError(err)

ts.Equal(uint64(0), partitions.NumPartitions)
}

func (ts *IncrementalPollTests) TestQuery() {
ctx := context.Background()
for name, testCase := range ts.testCases {
ts.Run(name, func() {
stmt, err := ts.cnxn.NewStatement()
ts.NoError(err)
defer stmt.Close()

ts.NoError(stmt.SetOption(adbc.OptionKeyIncremental, adbc.OptionValueEnabled))

// Run the query multiple times (we should be able to reuse the statement)
for i := 0; i < 2; i++ {
ts.RunOneTestCase(ctx, stmt, name, &testCase)
}
})
}
}

func (ts *IncrementalPollTests) TestQueryPrepared() {
}

func (ts *IncrementalPollTests) TestQueryPreparedTransaction() {
}

func (ts *IncrementalPollTests) TestQueryTransaction() {
}

func (ts *IncrementalPollTests) TestStatementReuse() {
}

// ---- Timeout Tests --------------------

type TimeoutTestServer struct {
Expand Down
Loading

0 comments on commit 0351559

Please sign in to comment.