Skip to content

Commit

Permalink
fix(go/adbc/driver/flightsql): use atomic for progress (#1520)
Browse files Browse the repository at this point in the history
Fixes #1504.
  • Loading branch information
lidavidm authored Feb 6, 2024
1 parent f5445ae commit f01ee5f
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions go/adbc/driver/flightsql/flightsql_statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ package flightsql
import (
"context"
"fmt"
"math"
"strconv"
"strings"
"sync/atomic"
"time"
"unsafe"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v16/arrow"
Expand All @@ -45,6 +48,14 @@ const (
OptionStatementSubstraitVersion = "adbc.flight.sql.substrait.version"
)

func atomicLoadFloat64(x *float64) float64 {
return math.Float64frombits(atomic.LoadUint64((*uint64)(unsafe.Pointer(x))))
}

func atomicStoreFloat64(x *float64, v float64) {
atomic.StoreUint64((*uint64)(unsafe.Pointer(x)), math.Float64bits(v))
}

type sqlOrSubstrait struct {
sqlQuery string
substraitPlan []byte
Expand Down Expand Up @@ -271,7 +282,7 @@ func (s *statement) GetOptionDouble(key string) (float64, error) {
case OptionTimeoutUpdate:
return s.timeouts.updateTimeout.Seconds(), nil
case adbc.OptionKeyProgress:
return s.progress, nil
return atomicLoadFloat64(&s.progress), nil
case adbc.OptionKeyMaxProgress:
return 1.0, nil
}
Expand Down Expand Up @@ -582,7 +593,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
totalRecords := s.incrementalState.previousInfo.TotalRecords
// Reset the statement for reuse
s.incrementalState = &incrementalState{}
s.progress = 0.0
atomicStoreFloat64(&s.progress, 0.0)
return schema, adbc.Partitions{}, totalRecords, nil
}

Expand All @@ -598,7 +609,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
// The server is misbehaving
// XXX: should we also issue a query cancellation?
s.incrementalState = &incrementalState{}
s.progress = 0.0
atomicStoreFloat64(&s.progress, 0.0)
return nil, adbc.Partitions{}, -1, adbc.Error{
Msg: "[Flight SQL] Server returned a PollInfo with no FlightInfo",
Code: adbc.StatusInternal,
Expand All @@ -616,7 +627,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
}
s.incrementalState.previousInfo = poll.GetInfo()
s.incrementalState.retryDescriptor = poll.GetFlightDescriptor()
s.progress = poll.GetProgress()
atomicStoreFloat64(&s.progress, poll.GetProgress())

if s.incrementalState.retryDescriptor == nil {
// Query is finished
Expand All @@ -639,7 +650,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
// returning 0 partitions implies completion)
if s.incrementalState.complete && len(info.Endpoint) == 0 {
s.incrementalState = &incrementalState{}
s.progress = 0.0
atomicStoreFloat64(&s.progress, 0.0)
}
} else if s.prepared != nil {
info, err = s.prepared.Execute(ctx, grpc.Header(&header), grpc.Trailer(&trailer), s.timeouts)
Expand Down

0 comments on commit f01ee5f

Please sign in to comment.