From bcbc1614c26ce4dcc35840b604773648502146bc Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Mar 2024 17:07:40 -0500 Subject: [PATCH] feat(go/adbc/driver/flightsql): reflect gRPC status in vendor code (#1577) Does not work in C/C++/Python due to #1576. Fixes #1574. --- .../flightsql/flightsql_adbc_server_test.go | 63 +++++++++++++++++++ go/adbc/driver/flightsql/utils.go | 34 +++++++--- python/adbc_driver_flightsql/pyproject.toml | 3 + .../tests/test_errors.py | 15 +++++ python/adbc_driver_manager/pyproject.toml | 1 + python/adbc_driver_postgresql/pyproject.toml | 1 + python/adbc_driver_snowflake/pyproject.toml | 3 + python/adbc_driver_sqlite/pyproject.toml | 3 + 8 files changed, 116 insertions(+), 7 deletions(-) diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go index f779e6aff4..66e94da44e 100644 --- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go +++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go @@ -253,6 +253,16 @@ func (srv *ErrorDetailsTestServer) GetFlightInfoStatement(ctx context.Context, q panic(err) } return &flight.FlightInfo{Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}}}, nil + } else if query.GetQuery() == "vendorcode" { + return nil, status.Errorf(codes.ResourceExhausted, "Resource exhausted") + } else if query.GetQuery() == "binaryheader" { + if err := grpc.SendHeader(ctx, metadata.Pairs("x-header-bin", string([]byte{0, 110}))); err != nil { + return nil, err + } + if err := grpc.SetTrailer(ctx, metadata.Pairs("x-trailer-bin", string([]byte{111, 0, 112}))); err != nil { + return nil, err + } + return nil, status.Errorf(codes.FailedPrecondition, "Resource exhausted") } return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented") } @@ -287,6 +297,43 @@ func (suite *ErrorDetailsTests) SetupSuite() { suite.DoSetupSuite(&srv, nil, nil) } +func (ts *ErrorDetailsTests) TestBinaryDetails() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("binaryheader")) + + _, _, err = stmt.ExecuteQuery(context.Background()) + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + + ts.Equal(int32(codes.FailedPrecondition), adbcErr.VendorCode) + + ts.Equal(2, len(adbcErr.Details)) + + headerFound := false + trailerFound := false + for _, wrapper := range adbcErr.Details { + switch wrapper.Key() { + case "x-header-bin": + val, err := wrapper.Serialize() + ts.NoError(err) + ts.Equal([]byte{0, 110}, val) + headerFound = true + case "x-trailer-bin": + val, err := wrapper.Serialize() + ts.NoError(err) + ts.Equal([]byte{111, 0, 112}, val) + trailerFound = true + default: + ts.Failf("Unexpected detail key: %s", wrapper.Key()) + } + } + ts.Truef(headerFound, "Did not find x-header-bin") + ts.Truef(trailerFound, "Did not find x-trailer-bin") +} + func (ts *ErrorDetailsTests) TestGetFlightInfo() { stmt, err := ts.cnxn.NewStatement() ts.NoError(err) @@ -298,6 +345,8 @@ func (ts *ErrorDetailsTests) TestGetFlightInfo() { var adbcErr adbc.Error ts.ErrorAs(err, &adbcErr) + ts.Equal(int32(codes.Unknown), adbcErr.VendorCode) + ts.Equal(1, len(adbcErr.Details)) wrapper := adbcErr.Details[0] @@ -347,6 +396,20 @@ func (ts *ErrorDetailsTests) TestDoGet() { ts.Equal(int32(42), message.Value) } +func (ts *ErrorDetailsTests) TestVendorCode() { + stmt, err := ts.cnxn.NewStatement() + ts.NoError(err) + defer stmt.Close() + + ts.NoError(stmt.SetSqlQuery("vendorcode")) + + _, _, err = stmt.ExecuteQuery(context.Background()) + var adbcErr adbc.Error + ts.ErrorAs(err, &adbcErr) + + ts.Equal(int32(codes.ResourceExhausted), adbcErr.VendorCode) +} + // ---- ExecuteSchema Tests -------------------- type ExecuteSchemaTestServer struct { diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go index fef8e73829..ea063ad9d8 100644 --- a/go/adbc/driver/flightsql/utils.go +++ b/go/adbc/driver/flightsql/utils.go @@ -20,6 +20,7 @@ package flightsql import ( "context" "fmt" + "strings" "github.com/apache/arrow-adbc/go/adbc" "google.golang.org/grpc/codes" @@ -92,9 +93,18 @@ func adbcFromFlightStatusWithDetails(err error, header, trailer metadata.MD, con // XXX: must check both headers and trailers because some implementations // (like gRPC-Java) will consolidate trailers into headers for failed RPCs for key, values := range header { - switch key { - case "content-type", "grpc-status-details-bin": + switch { + case key == "content-type": + // Not useful info continue + case key == "grpc-status-details-bin": + // gRPC library parses this above via grpcStatus.Proto() + continue + case strings.HasSuffix(key, "-bin"): + for _, value := range values { + // that's right, gRPC stuffs binary data into a "string" + details = append(details, &adbc.BinaryErrorDetail{Name: key, Detail: []byte(value)}) + } default: for _, value := range values { details = append(details, &adbc.TextErrorDetail{Name: key, Detail: value}) @@ -102,9 +112,18 @@ func adbcFromFlightStatusWithDetails(err error, header, trailer metadata.MD, con } } for key, values := range trailer { - switch key { - case "content-type", "grpc-status-details-bin": + switch { + case key == "content-type": + // Not useful info continue + case key == "grpc-status-details-bin": + // gRPC library parses this above via grpcStatus.Proto() + continue + case strings.HasSuffix(key, "-bin"): + for _, value := range values { + // that's right, gRPC stuffs binary data into a "string" + details = append(details, &adbc.BinaryErrorDetail{Name: key, Detail: []byte(value)}) + } default: for _, value := range values { details = append(details, &adbc.TextErrorDetail{Name: key, Detail: value}) @@ -114,9 +133,10 @@ func adbcFromFlightStatusWithDetails(err error, header, trailer metadata.MD, con return adbc.Error{ // People don't read error messages, so backload the context and frontload the server error - Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)), - Code: adbcCode, - Details: details, + Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)), + Code: adbcCode, + VendorCode: int32(grpcStatus.Code()), + Details: details, } } diff --git a/python/adbc_driver_flightsql/pyproject.toml b/python/adbc_driver_flightsql/pyproject.toml index ba3485f9f5..35362655ee 100644 --- a/python/adbc_driver_flightsql/pyproject.toml +++ b/python/adbc_driver_flightsql/pyproject.toml @@ -44,3 +44,6 @@ include-package-data = true license-files = ["LICENSE.txt", "NOTICE.txt"] packages = ["adbc_driver_flightsql"] py-modules = ["adbc_driver_flightsql"] + +[tool.pytest.ini_options] +xfail_strict = true diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py index ee2b62d3ee..bed1878fc7 100644 --- a/python/adbc_driver_flightsql/tests/test_errors.py +++ b/python/adbc_driver_flightsql/tests/test_errors.py @@ -84,6 +84,21 @@ def test_query_error_fetch(test_dbapi): assert_detail(excval.value) +@pytest.mark.xfail(reason="apache/arrow-adbc#1576") +def test_query_error_vendor_code(test_dbapi): + with test_dbapi.cursor() as cur: + cur.execute("error_do_get") + with pytest.raises( + test_dbapi.ProgrammingError, + match=re.escape("INVALID_ARGUMENT: [FlightSQL] expected error (DoGet)"), + ) as excval: + cur.fetch_arrow_table() + + # TODO(https://github.com/apache/arrow-adbc/issues/1576): vendor code + # is gRPC status code; 3 is gRPC INVALID_ARGUMENT + assert excval.value.vendor_code == 3 + + def test_query_error_stream(test_dbapi): with test_dbapi.cursor() as cur: cur.execute("error_do_get_stream") diff --git a/python/adbc_driver_manager/pyproject.toml b/python/adbc_driver_manager/pyproject.toml index d2db1f102f..16f7d00eed 100644 --- a/python/adbc_driver_manager/pyproject.toml +++ b/python/adbc_driver_manager/pyproject.toml @@ -43,6 +43,7 @@ markers = [ "panicdummy: tests that require the testing-only panicdummy driver", "sqlite: tests that require the SQLite driver", ] +xfail_strict = true [tool.setuptools] license-files = ["LICENSE.txt", "NOTICE.txt"] diff --git a/python/adbc_driver_postgresql/pyproject.toml b/python/adbc_driver_postgresql/pyproject.toml index a5acc22ea4..298c8b0cb5 100644 --- a/python/adbc_driver_postgresql/pyproject.toml +++ b/python/adbc_driver_postgresql/pyproject.toml @@ -43,6 +43,7 @@ build-backend = "setuptools.build_meta" markers = [ "polars: integration tests with polars", ] +xfail_strict = true [tool.setuptools] include-package-data = true diff --git a/python/adbc_driver_snowflake/pyproject.toml b/python/adbc_driver_snowflake/pyproject.toml index 3f64e831da..884535fd2b 100644 --- a/python/adbc_driver_snowflake/pyproject.toml +++ b/python/adbc_driver_snowflake/pyproject.toml @@ -44,3 +44,6 @@ include-package-data = true license-files = ["LICENSE.txt", "NOTICE.txt"] packages = ["adbc_driver_snowflake"] py-modules = ["adbc_driver_snowflake"] + +[tool.pytest.ini_options] +xfail_strict = true diff --git a/python/adbc_driver_sqlite/pyproject.toml b/python/adbc_driver_sqlite/pyproject.toml index 3f7acc01bf..282cccf4ec 100644 --- a/python/adbc_driver_sqlite/pyproject.toml +++ b/python/adbc_driver_sqlite/pyproject.toml @@ -44,3 +44,6 @@ include-package-data = true license-files = ["LICENSE.txt", "NOTICE.txt"] packages = ["adbc_driver_sqlite"] py-modules = ["adbc_driver_sqlite"] + +[tool.pytest.ini_options] +xfail_strict = true