From 2e39c73523aded5c9985ddb83457071e42f0f5b2 Mon Sep 17 00:00:00 2001
From: Madhavendra Rathore
Date: Wed, 3 Jun 2026 00:27:07 +0530
Subject: [PATCH] Detach result streaming from QueryContext cancellation (ES-1934053)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

PR #295 threaded the caller context into NewRows and the result page
iterator, so FetchResults paging and CloseOperation inherited the
caller's deadline. A short QueryContext timeout meant to gate only
statement submission and status polling then fired mid-stream, silently
truncating large CloudFetch results: a query expected to return
29,232,004 rows returned only 2,159,144. The ArrowBatchIterator surfaced
io.EOF rather than the deadline error, so the truncation was silent.

Detach the result context from the caller's cancellation via
context.WithoutCancel (preserving its values for auth/logging), but wire
it to a cancel func invoked from Rows.Close() so in-flight FetchResults
and CloudFetch downloads are never left uncancellable. GetArrowBatches /
GetArrowIPCStreams now build their iterator from this detached context so
CloudFetch S3 downloads also survive the caller's deadline and remain
abortable via Close — the previous behaviour cancelled downloads but not
paging, a half-effective split.

Adds unit tests for the detached-but-abortable contract and updates the
e2e regression to pass the cancelled QueryContext into GetArrowBatches.

Co-authored-by: Isaac Signed-off-by: Madhavendra Rathore --- driver_e2e_test.go | 65 ++++++++++++++++++ internal/rows/rows.go | 53 +++++++++++--- internal/rows/rows_test.go | 137 +++++++++++++++++++++++++++++++++++++ 3 files changed, 245 insertions(+), 10 deletions(-) diff --git a/driver_e2e_test.go b/driver_e2e_test.go index fdd3a538..918f758f 100644 --- a/driver_e2e_test.go +++ b/driver_e2e_test.go @@ -3,8 +3,10 @@ package dbsql import ( "context" "database/sql" + "database/sql/driver" "encoding/json" "fmt" + "io" "net/http/httptest" "net/url" "os" @@ -17,6 +19,7 @@ import ( "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/client" "github.com/databricks/databricks-sql-go/logger" + dbsqlrows "github.com/databricks/databricks-sql-go/rows" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -255,6 +258,68 @@ func TestWorkflowExample(t *testing.T) { } } +func TestE2EArrowBatchesSurviveQueryContextCancellation(t *testing.T) { + host := os.Getenv("DATABRICKS_PECOTESTING_SERVER_HOSTNAME") + httpPath := os.Getenv("DATABRICKS_PECOTESTING_HTTP_PATH2") + token := os.Getenv("DATABRICKS_PECOTESTING_TOKEN") + if token == "" { + token = os.Getenv("DATABRICKS_PECOTESTING_TOKEN_PERSONAL") + } + if host == "" || httpPath == "" || token == "" { + t.Skip("set DATABRICKS_PECOTESTING_SERVER_HOSTNAME, DATABRICKS_PECOTESTING_HTTP_PATH2, and DATABRICKS_PECOTESTING_TOKEN to run") + } + + connector, err := NewConnector( + WithServerHostname(host), + WithPort(443), + WithHTTPPath(httpPath), + WithAccessToken(token), + WithMaxRows(1), + ) + require.NoError(t, err) + + db := sql.OpenDB(connector) + defer db.Close() //nolint:errcheck + + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + + queryCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var driverRows driver.Rows + err = conn.Raw(func(d 
any) error { + var queryErr error + driverRows, queryErr = d.(driver.QueryerContext).QueryContext(queryCtx, "SELECT id FROM range(3)", nil) + return queryErr + }) + require.NoError(t, err) + defer driverRows.Close() //nolint:errcheck + + cancel() + + // Pass the already-cancelled queryCtx (not context.Background()) so the test + // exercises the detached-iterator path: result paging AND CloudFetch + // downloads must survive cancellation of the ctx handed to GetArrowBatches. + batches, err := driverRows.(dbsqlrows.Rows).GetArrowBatches(queryCtx) + require.NoError(t, err) + defer batches.Close() + + var rowCount int64 + for { + record, nextErr := batches.Next() + if nextErr == io.EOF { + break + } + require.NoError(t, nextErr) + rowCount += record.NumRows() + record.Release() + } + + require.Equal(t, int64(3), rowCount) +} + func TestContextTimeoutExample(t *testing.T) { _ = logger.SetLogLevel("debug") diff --git a/internal/rows/rows.go b/internal/rows/rows.go index bd0c2605..c2fa530f 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -13,6 +13,7 @@ import ( dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" dbsqlclient "github.com/databricks/databricks-sql-go/internal/client" + context2 "github.com/databricks/databricks-sql-go/internal/compat/context" "github.com/databricks/databricks-sql-go/internal/config" dbsqlerr_int "github.com/databricks/databricks-sql-go/internal/errors" "github.com/databricks/databricks-sql-go/internal/rows/arrowbased" @@ -57,7 +58,16 @@ type rows struct { logger_ *dbsqllog.DBSQLLogger + // ctx is the context used for all server-side result RPCs (FetchResults, + // GetResultSetMetadata, CloseOperation) and CloudFetch downloads. It is + // detached from the caller's QueryContext cancellation so that a deadline + // gating statement submission does not truncate result streaming, while + // preserving context values used for auth/logging. 
It remains abortable via + // Close() through resultsCancel. ctx context.Context + // resultsCancel aborts in-flight result RPCs/downloads when Close() is + // called, so the detached ctx never leaves an operation uncancellable. + resultsCancel context.CancelFunc // Telemetry tracking // telemetryUpdate is called after each chunk is fetched with: @@ -134,6 +144,15 @@ func NewRows( logger.Debug().Msgf("databricks: creating Rows, pageSize: %d, location: %v", pageSize, location) + // QueryContext may use a short deadline to gate statement submission and + // status polling (see ES-1934053 / #295 / #371). Result handles can outlive + // that phase, especially for paginated CloudFetch streams, so detach + // server-side result RPCs from the caller's cancellation while preserving + // context values used for auth/logging. The detached context is still wired + // to a cancel func invoked from Close(), so the result handle remains + // abortable (no uncancellable in-flight FetchResults or CloudFetch download). + resultsCtx, resultsCancel := context.WithCancel(context2.WithoutCancel(ctx)) + r := &rows{ client: client, opHandle: opHandle, @@ -142,7 +161,8 @@ func NewRows( location: location, config: config, logger_: logger, - ctx: ctx, + ctx: resultsCtx, + resultsCancel: resultsCancel, chunkCount: 0, bytesDownloaded: 0, } @@ -201,7 +221,7 @@ func NewRows( // the operations. closedOnServer := directResults != nil && directResults.CloseOperation != nil r.ResultPageIterator = rowscanner.NewResultPageIterator( - ctx, + resultsCtx, d, pageSize, opHandle, @@ -244,6 +264,12 @@ func (r *rows) Close() error { return nil } + // Release the detached results context after the close RPC runs, aborting + // any in-flight FetchResults/CloudFetch downloads still referencing it. 
+ if r.resultsCancel != nil { + defer r.resultsCancel() + } + if r.RowScanner != nil { // make sure the row scanner frees up any resources r.RowScanner.Close() @@ -635,27 +661,34 @@ func (r *rows) logger() *dbsqllog.DBSQLLogger { } func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterator, error) { - // update context with correlationId and connectionId which will be used in logging and errors - ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId) + // Result fetching must outlive the caller's QueryContext deadline: both the + // inter-page FetchResults RPCs (via r.ResultPageIterator) AND the CloudFetch + // S3 downloads created from the iterator context. Build the iterator from the + // detached results context (abortable via Close) rather than the caller ctx, + // so passing a deadline-bound ctx here cannot truncate the stream. Driver + // values for logging/auth are already carried by r.ctx; re-apply the ids + // defensively. See ES-1934053 / #371. 
+ iterCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(r.ctx, r.connId), r.correlationId) // If a row scanner exists we use it to create the iterator, that way the iterator includes // data returned as direct results if r.RowScanner != nil { - return r.RowScanner.GetArrowBatches(ctx, *r.config, r.ResultPageIterator) + return r.RowScanner.GetArrowBatches(iterCtx, *r.config, r.ResultPageIterator) } - return arrowbased.NewArrowRecordIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil + return arrowbased.NewArrowRecordIterator(iterCtx, r.ResultPageIterator, nil, nil, *r.config), nil } func (r *rows) GetArrowIPCStreams(ctx context.Context) (dbsqlrows.ArrowIPCStreamIterator, error) { - // update context with correlationId and connectionId which will be used in logging and errors - ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId) + // See GetArrowBatches: result fetching is detached from the caller ctx so a + // submit-gating deadline cannot truncate streaming; it stays abortable via Close. 
+ iterCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(r.ctx, r.connId), r.correlationId) // If a row scanner exists we use it to create the iterator, that way the iterator includes // data returned as direct results if r.RowScanner != nil { - return r.RowScanner.GetArrowIPCStreams(ctx, *r.config, r.ResultPageIterator) + return r.RowScanner.GetArrowIPCStreams(iterCtx, *r.config, r.ResultPageIterator) } - return arrowbased.NewArrowIPCStreamIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil + return arrowbased.NewArrowIPCStreamIterator(iterCtx, r.ResultPageIterator, nil, nil, *r.config), nil } diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index 8947bf05..e04f9ef5 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -1564,6 +1564,143 @@ func TestFetchResultPage_PropagatesGetNextPageError(t *testing.T) { assert.ErrorContains(t, actualErr, errorMsg) } +func TestNewRows_DetachesResultRPCContextFromQueryContextCancellation(t *testing.T) { + t.Parallel() + + baseCtx := driverctx.NewContextWithConnId(context.Background(), "connId") + baseCtx = driverctx.NewContextWithCorrelationId(baseCtx, "corrId") + queryCtx, cancel := context.WithCancel(baseCtx) + cancel() + + assertResultCtx := func(ctx context.Context) { + assert.NoError(t, ctx.Err(), "result RPC context should not inherit query cancellation") + assert.Equal(t, "connId", driverctx.ConnIdFromContext(ctx)) + assert.Equal(t, "corrId", driverctx.CorrelationIdFromContext(ctx)) + } + + metaCalled := false + metaFn := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) { + metaCalled = true + assertResultCtx(ctx) + return &cli_service.TGetResultSetMetadataResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + Schema: &cli_service.TTableSchema{ + Columns: []*cli_service.TColumnDesc{ + {ColumnName: "flag", Position: 0, TypeDesc: 
&cli_service.TTypeDesc{ + Types: []*cli_service.TTypeEntry{{ + PrimitiveEntry: &cli_service.TPrimitiveTypeEntry{Type: cli_service.TTypeId_BOOLEAN_TYPE}, + }}, + }}, + }, + }, + }, nil + } + + noMoreRows := false + fetchCalled := false + fetchFn := func(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) { + fetchCalled = true + assertResultCtx(ctx) + return &cli_service.TFetchResultsResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + HasMoreRows: &noMoreRows, + Results: &cli_service.TRowSet{ + StartRowOffset: 0, + Columns: []*cli_service.TColumn{ + {BoolVal: &cli_service.TBoolColumn{Values: []bool{true}}}, + }, + }, + }, nil + } + + closeCalled := false + closeFn := func(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) { + closeCalled = true + assertResultCtx(ctx) + return &cli_service.TCloseOperationResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + }, nil + } + + testClient := &client.TestClient{ + FnFetchResults: fetchFn, + FnGetResultSetMetadata: metaFn, + FnCloseOperation: closeFn, + } + opHandle := &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{GUID: []byte("operation-id")}, + } + cfg := config.WithDefaults() + + dr, dbErr := NewRows(queryCtx, opHandle, testClient, cfg, nil, nil) + assert.Nil(t, dbErr) + + dest := make([]driver.Value, 1) + assert.NoError(t, dr.Next(dest)) + assert.Equal(t, true, dest[0]) + assert.True(t, fetchCalled, "FetchResults should use the detached result context") + assert.True(t, metaCalled, "GetResultSetMetadata should use the detached result context") + assert.True(t, closeCalled, "CloseOperation should use the detached result context") +} + +// TestNewRows_CloseAbortsDetachedResultContext verifies the detachment is not +// total: the result context survives the caller's QueryContext cancellation +// (so streaming is not 
truncated) but is still cancelled by Close(), so an +// in-flight FetchResults/CloudFetch download can never be left uncancellable. +func TestNewRows_CloseAbortsDetachedResultContext(t *testing.T) { + t.Parallel() + + queryCtx, cancel := context.WithCancel(context.Background()) + + var capturedCtx context.Context + noMoreRows := false + fetchFn := func(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) { + capturedCtx = ctx + return &cli_service.TFetchResultsResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + HasMoreRows: &noMoreRows, + Results: &cli_service.TRowSet{ + StartRowOffset: 0, + Columns: []*cli_service.TColumn{{BoolVal: &cli_service.TBoolColumn{Values: []bool{true}}}}, + }, + }, nil + } + metaFn := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) { + return &cli_service.TGetResultSetMetadataResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + Schema: &cli_service.TTableSchema{Columns: []*cli_service.TColumnDesc{ + {ColumnName: "flag", Position: 0, TypeDesc: &cli_service.TTypeDesc{Types: []*cli_service.TTypeEntry{{ + PrimitiveEntry: &cli_service.TPrimitiveTypeEntry{Type: cli_service.TTypeId_BOOLEAN_TYPE}, + }}}}, + }}, + }, nil + } + closeFn := func(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) { + // The close RPC itself must still run with a live (un-cancelled) context. 
+ assert.NoError(t, ctx.Err(), "CloseOperation must run before the result context is cancelled") + return &cli_service.TCloseOperationResp{Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}}, nil + } + + testClient := &client.TestClient{FnFetchResults: fetchFn, FnGetResultSetMetadata: metaFn, FnCloseOperation: closeFn} + opHandle := &cli_service.TOperationHandle{OperationId: &cli_service.THandleIdentifier{GUID: []byte("operation-id")}} + + dr, dbErr := NewRows(queryCtx, opHandle, testClient, config.WithDefaults(), nil, nil) + assert.Nil(t, dbErr) + + dest := make([]driver.Value, 1) + assert.NoError(t, dr.Next(dest)) + assert.NotNil(t, capturedCtx) + + // Caller cancels the QueryContext: result context must remain alive. + cancel() + assert.NoError(t, capturedCtx.Err(), "result context must survive QueryContext cancellation") + assert.NotNil(t, capturedCtx.Done(), "result context must be abortable (non-nil Done)") + + // Close() must cancel the detached result context so nothing is left uncancellable. + assert.NoError(t, dr.Close()) + assert.ErrorIs(t, capturedCtx.Err(), context.Canceled, "Close() should cancel the detached result context") +} + // TestRows_CloseCallback_ReceivesChunkCount verifies that when rows.Close() is called, // the closeCallback receives the correct chunkCount reflecting the number of result pages // that were fetched during iteration.