diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index 3df7c9a33c..3d7b8ce8cc 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -23,6 +23,7 @@ import ( "context" "errors" "fmt" + "net" "net/textproto" "os" "strconv" @@ -810,11 +811,18 @@ func (ts *IncrementalPollTests) TestQueryTransaction() { type TimeoutTestServer struct { flightsql.BaseServer + badPort int + goodPort int } func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) { - if string(tkt.GetStatementHandle()) == "sleep and succeed" { + ticket := string(tkt.GetStatementHandle()) + if ticket == "sleep and succeed" { time.Sleep(1 * time.Second) + } + + switch ticket { + case "bad endpoint", "sleep and succeed": sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil) rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`)) if err != nil { @@ -850,6 +858,23 @@ func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd fli switch cmd.GetQuery() { case "timeout": <-ctx.Done() + case "bad endpoint": + tkt, _ := flightsql.CreateStatementQueryTicket([]byte("bad endpoint")) + info := &flight.FlightInfo{ + FlightDescriptor: desc, + Endpoint: []*flight.FlightEndpoint{ + { + Ticket: &flight.Ticket{Ticket: tkt}, + Location: []*flight.Location{ + {Uri: fmt.Sprintf("grpc://localhost:%d", ts.badPort)}, + {Uri: fmt.Sprintf("grpc://localhost:%d", ts.goodPort)}, + }, + }, + }, + TotalRecords: -1, + TotalBytes: -1, + } + return info, nil case "fetch": tkt, _ := flightsql.CreateStatementQueryTicket([]byte("fetch")) info := &flight.FlightInfo{ @@ -884,10 +909,23 @@ func (ts *TimeoutTestServer) CreatePreparedStatement(ctx context.Context, req fl type TimeoutTests struct { ServerBasedTests + server net.Listener } func (suite *TimeoutTests) SetupSuite() { - suite.DoSetupSuite(&TimeoutTestServer{}, nil, nil) + var err error + suite.server, err = net.Listen("tcp", "localhost:0") + suite.NoError(err) + + badPort := suite.server.Addr().(*net.TCPAddr).Port + server := &TimeoutTestServer{badPort: badPort} + suite.DoSetupSuite(server, nil, nil) + server.goodPort = suite.s.Addr().(*net.TCPAddr).Port +} + +func (suite *TimeoutTests) TearDownSuite() { + suite.ServerBasedTests.TearDownSuite() + suite.NoError(suite.server.Close()) } func (ts *TimeoutTests) TestInvalidValues() { @@ -1075,6 +1113,26 @@ func (ts *TimeoutTests) TestDontTimeout() { ts.Truef(array.RecordEqual(rec, expected), "expected: %s\nactual: %s", expected, rec) } +func (ts *TimeoutTests) TestBadAddress() { + stmt, err := ts.cnxn.NewStatement() + ts.Require().NoError(err) + defer stmt.Close() + ts.Require().NoError(stmt.SetSqlQuery("bad endpoint")) + // XXX: this first attempt takes about 20 seconds, presumably due to + // some setting in grpc-go, but there's no obvious knob to tweak it + rr, _, err := stmt.ExecuteQuery(context.Background()) + ts.Require().NoError(err) + defer rr.Release() + + rr, _, err = stmt.ExecuteQuery(context.Background()) + ts.Require().NoError(err) + defer rr.Release() + + rr, _, err = stmt.ExecuteQuery(context.Background()) + ts.Require().NoError(err) + defer rr.Release() +} + // ---- Cookie Tests -------------------- type CookieTestServer struct { flightsql.BaseServer diff --git a/go/adbc/driver/flightsql/flightsql_database.go b/go/adbc/driver/flightsql/flightsql_database.go index 1407fedf18..22f35b60f2 100644 --- a/go/adbc/driver/flightsql/flightsql_database.go +++ b/go/adbc/driver/flightsql/flightsql_database.go @@ -368,6 +368,7 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl } dialOpts := append(d.dialOpts.opts, grpc.WithTransportCredentials(creds)) + d.Logger.DebugContext(ctx, "new client", "location", loc) cl, err := flightsql.NewClient(target, nil, middleware, dialOpts...) if err != nil { return nil, adbc.Error{ @@ -395,7 +396,6 @@ func getFlightClient(ctx context.Context, loc string, d *databaseImpl, authMiddl } } - d.Logger.DebugContext(ctx, "new client", "location", loc) return cl, nil }