From bcbc1614c26ce4dcc35840b604773648502146bc Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 1 Mar 2024 17:07:40 -0500
Subject: [PATCH] feat(go/adbc/driver/flightsql): reflect gRPC status in vendor
code (#1577)
Does not work in C/C++/Python due to #1576.
Fixes #1574.
---
.../flightsql/flightsql_adbc_server_test.go | 63 +++++++++++++++++++
go/adbc/driver/flightsql/utils.go | 34 +++++++---
python/adbc_driver_flightsql/pyproject.toml | 3 +
.../tests/test_errors.py | 15 +++++
python/adbc_driver_manager/pyproject.toml | 1 +
python/adbc_driver_postgresql/pyproject.toml | 1 +
python/adbc_driver_snowflake/pyproject.toml | 3 +
python/adbc_driver_sqlite/pyproject.toml | 3 +
8 files changed, 116 insertions(+), 7 deletions(-)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index f779e6aff4..66e94da44e 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -253,6 +253,16 @@ func (srv *ErrorDetailsTestServer) GetFlightInfoStatement(ctx context.Context, q
panic(err)
}
return &flight.FlightInfo{Endpoint: []*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: tkt}}}}, nil
+ } else if query.GetQuery() == "vendorcode" {
+ return nil, status.Errorf(codes.ResourceExhausted, "Resource exhausted")
+ } else if query.GetQuery() == "binaryheader" {
+ if err := grpc.SendHeader(ctx, metadata.Pairs("x-header-bin", string([]byte{0, 110}))); err != nil {
+ return nil, err
+ }
+ if err := grpc.SetTrailer(ctx, metadata.Pairs("x-trailer-bin", string([]byte{111, 0, 112}))); err != nil {
+ return nil, err
+ }
+ return nil, status.Errorf(codes.FailedPrecondition, "Resource exhausted")
}
return nil, status.Errorf(codes.Unimplemented, "GetSchemaStatement not implemented")
}
@@ -287,6 +297,43 @@ func (suite *ErrorDetailsTests) SetupSuite() {
suite.DoSetupSuite(&srv, nil, nil)
}
+func (ts *ErrorDetailsTests) TestBinaryDetails() {
+ stmt, err := ts.cnxn.NewStatement()
+ ts.NoError(err)
+ defer stmt.Close()
+
+ ts.NoError(stmt.SetSqlQuery("binaryheader"))
+
+ _, _, err = stmt.ExecuteQuery(context.Background())
+ var adbcErr adbc.Error
+ ts.ErrorAs(err, &adbcErr)
+
+ ts.Equal(int32(codes.FailedPrecondition), adbcErr.VendorCode)
+
+ ts.Equal(2, len(adbcErr.Details))
+
+ headerFound := false
+ trailerFound := false
+ for _, wrapper := range adbcErr.Details {
+ switch wrapper.Key() {
+ case "x-header-bin":
+ val, err := wrapper.Serialize()
+ ts.NoError(err)
+ ts.Equal([]byte{0, 110}, val)
+ headerFound = true
+ case "x-trailer-bin":
+ val, err := wrapper.Serialize()
+ ts.NoError(err)
+ ts.Equal([]byte{111, 0, 112}, val)
+ trailerFound = true
+ default:
+ ts.Failf("Unexpected detail key: %s", wrapper.Key())
+ }
+ }
+ ts.Truef(headerFound, "Did not find x-header-bin")
+ ts.Truef(trailerFound, "Did not find x-trailer-bin")
+}
+
func (ts *ErrorDetailsTests) TestGetFlightInfo() {
stmt, err := ts.cnxn.NewStatement()
ts.NoError(err)
@@ -298,6 +345,8 @@ func (ts *ErrorDetailsTests) TestGetFlightInfo() {
var adbcErr adbc.Error
ts.ErrorAs(err, &adbcErr)
+ ts.Equal(int32(codes.Unknown), adbcErr.VendorCode)
+
ts.Equal(1, len(adbcErr.Details))
wrapper := adbcErr.Details[0]
@@ -347,6 +396,20 @@ func (ts *ErrorDetailsTests) TestDoGet() {
ts.Equal(int32(42), message.Value)
}
+func (ts *ErrorDetailsTests) TestVendorCode() {
+ stmt, err := ts.cnxn.NewStatement()
+ ts.NoError(err)
+ defer stmt.Close()
+
+ ts.NoError(stmt.SetSqlQuery("vendorcode"))
+
+ _, _, err = stmt.ExecuteQuery(context.Background())
+ var adbcErr adbc.Error
+ ts.ErrorAs(err, &adbcErr)
+
+ ts.Equal(int32(codes.ResourceExhausted), adbcErr.VendorCode)
+}
+
// ---- ExecuteSchema Tests --------------------
type ExecuteSchemaTestServer struct {
diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go
index fef8e73829..ea063ad9d8 100644
--- a/go/adbc/driver/flightsql/utils.go
+++ b/go/adbc/driver/flightsql/utils.go
@@ -20,6 +20,7 @@ package flightsql
import (
"context"
"fmt"
+ "strings"
"github.com/apache/arrow-adbc/go/adbc"
"google.golang.org/grpc/codes"
@@ -92,9 +93,18 @@ func adbcFromFlightStatusWithDetails(err error, header, trailer metadata.MD, con
// XXX: must check both headers and trailers because some implementations
// (like gRPC-Java) will consolidate trailers into headers for failed RPCs
for key, values := range header {
- switch key {
- case "content-type", "grpc-status-details-bin":
+ switch {
+ case key == "content-type":
+ // Not useful info
continue
+ case key == "grpc-status-details-bin":
+ // gRPC library parses this above via grpcStatus.Proto()
+ continue
+ case strings.HasSuffix(key, "-bin"):
+ for _, value := range values {
+ // that's right, gRPC stuffs binary data into a "string"
+ details = append(details, &adbc.BinaryErrorDetail{Name: key, Detail: []byte(value)})
+ }
default:
for _, value := range values {
details = append(details, &adbc.TextErrorDetail{Name: key, Detail: value})
@@ -102,9 +112,18 @@ func adbcFromFlightStatusWithDetails(err error, header, trailer metadata.MD, con
}
}
for key, values := range trailer {
- switch key {
- case "content-type", "grpc-status-details-bin":
+ switch {
+ case key == "content-type":
+ // Not useful info
continue
+ case key == "grpc-status-details-bin":
+ // gRPC library parses this above via grpcStatus.Proto()
+ continue
+ case strings.HasSuffix(key, "-bin"):
+ for _, value := range values {
+ // that's right, gRPC stuffs binary data into a "string"
+ details = append(details, &adbc.BinaryErrorDetail{Name: key, Detail: []byte(value)})
+ }
default:
for _, value := range values {
details = append(details, &adbc.TextErrorDetail{Name: key, Detail: value})
@@ -114,9 +133,10 @@ func adbcFromFlightStatusWithDetails(err error, header, trailer metadata.MD, con
return adbc.Error{
// People don't read error messages, so backload the context and frontload the server error
- Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)),
- Code: adbcCode,
- Details: details,
+ Msg: fmt.Sprintf("[FlightSQL] %s (%s; %s)", grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)),
+ Code: adbcCode,
+ VendorCode: int32(grpcStatus.Code()),
+ Details: details,
}
}
diff --git a/python/adbc_driver_flightsql/pyproject.toml b/python/adbc_driver_flightsql/pyproject.toml
index ba3485f9f5..35362655ee 100644
--- a/python/adbc_driver_flightsql/pyproject.toml
+++ b/python/adbc_driver_flightsql/pyproject.toml
@@ -44,3 +44,6 @@ include-package-data = true
license-files = ["LICENSE.txt", "NOTICE.txt"]
packages = ["adbc_driver_flightsql"]
py-modules = ["adbc_driver_flightsql"]
+
+[tool.pytest.ini_options]
+xfail_strict = true
diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py
index ee2b62d3ee..bed1878fc7 100644
--- a/python/adbc_driver_flightsql/tests/test_errors.py
+++ b/python/adbc_driver_flightsql/tests/test_errors.py
@@ -84,6 +84,21 @@ def test_query_error_fetch(test_dbapi):
assert_detail(excval.value)
+@pytest.mark.xfail(reason="apache/arrow-adbc#1576")
+def test_query_error_vendor_code(test_dbapi):
+ with test_dbapi.cursor() as cur:
+ cur.execute("error_do_get")
+ with pytest.raises(
+ test_dbapi.ProgrammingError,
+ match=re.escape("INVALID_ARGUMENT: [FlightSQL] expected error (DoGet)"),
+ ) as excval:
+ cur.fetch_arrow_table()
+
+ # TODO(https://github.com/apache/arrow-adbc/issues/1576): vendor code
+ # is gRPC status code; 3 is gRPC INVALID_ARGUMENT
+ assert excval.value.vendor_code == 3
+
+
def test_query_error_stream(test_dbapi):
with test_dbapi.cursor() as cur:
cur.execute("error_do_get_stream")
diff --git a/python/adbc_driver_manager/pyproject.toml b/python/adbc_driver_manager/pyproject.toml
index d2db1f102f..16f7d00eed 100644
--- a/python/adbc_driver_manager/pyproject.toml
+++ b/python/adbc_driver_manager/pyproject.toml
@@ -43,6 +43,7 @@ markers = [
"panicdummy: tests that require the testing-only panicdummy driver",
"sqlite: tests that require the SQLite driver",
]
+xfail_strict = true
[tool.setuptools]
license-files = ["LICENSE.txt", "NOTICE.txt"]
diff --git a/python/adbc_driver_postgresql/pyproject.toml b/python/adbc_driver_postgresql/pyproject.toml
index a5acc22ea4..298c8b0cb5 100644
--- a/python/adbc_driver_postgresql/pyproject.toml
+++ b/python/adbc_driver_postgresql/pyproject.toml
@@ -43,6 +43,7 @@ build-backend = "setuptools.build_meta"
markers = [
"polars: integration tests with polars",
]
+xfail_strict = true
[tool.setuptools]
include-package-data = true
diff --git a/python/adbc_driver_snowflake/pyproject.toml b/python/adbc_driver_snowflake/pyproject.toml
index 3f64e831da..884535fd2b 100644
--- a/python/adbc_driver_snowflake/pyproject.toml
+++ b/python/adbc_driver_snowflake/pyproject.toml
@@ -44,3 +44,6 @@ include-package-data = true
license-files = ["LICENSE.txt", "NOTICE.txt"]
packages = ["adbc_driver_snowflake"]
py-modules = ["adbc_driver_snowflake"]
+
+[tool.pytest.ini_options]
+xfail_strict = true
diff --git a/python/adbc_driver_sqlite/pyproject.toml b/python/adbc_driver_sqlite/pyproject.toml
index 3f7acc01bf..282cccf4ec 100644
--- a/python/adbc_driver_sqlite/pyproject.toml
+++ b/python/adbc_driver_sqlite/pyproject.toml
@@ -44,3 +44,6 @@ include-package-data = true
license-files = ["LICENSE.txt", "NOTICE.txt"]
packages = ["adbc_driver_sqlite"]
py-modules = ["adbc_driver_sqlite"]
+
+[tool.pytest.ini_options]
+xfail_strict = true