diff --git a/driver_e2e_test.go b/driver_e2e_test.go index fdd3a538..918f758f 100644 --- a/driver_e2e_test.go +++ b/driver_e2e_test.go @@ -3,8 +3,10 @@ package dbsql import ( "context" "database/sql" + "database/sql/driver" "encoding/json" "fmt" + "io" "net/http/httptest" "net/url" "os" @@ -17,6 +19,7 @@ import ( "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/client" "github.com/databricks/databricks-sql-go/logger" + dbsqlrows "github.com/databricks/databricks-sql-go/rows" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -255,6 +258,68 @@ func TestWorkflowExample(t *testing.T) { } } +func TestE2EArrowBatchesSurviveQueryContextCancellation(t *testing.T) { + host := os.Getenv("DATABRICKS_PECOTESTING_SERVER_HOSTNAME") + httpPath := os.Getenv("DATABRICKS_PECOTESTING_HTTP_PATH2") + token := os.Getenv("DATABRICKS_PECOTESTING_TOKEN") + if token == "" { + token = os.Getenv("DATABRICKS_PECOTESTING_TOKEN_PERSONAL") + } + if host == "" || httpPath == "" || token == "" { + t.Skip("set DATABRICKS_PECOTESTING_SERVER_HOSTNAME, DATABRICKS_PECOTESTING_HTTP_PATH2, and DATABRICKS_PECOTESTING_TOKEN to run") + } + + connector, err := NewConnector( + WithServerHostname(host), + WithPort(443), + WithHTTPPath(httpPath), + WithAccessToken(token), + WithMaxRows(1), + ) + require.NoError(t, err) + + db := sql.OpenDB(connector) + defer db.Close() //nolint:errcheck + + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + + queryCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var driverRows driver.Rows + err = conn.Raw(func(d any) error { + var queryErr error + driverRows, queryErr = d.(driver.QueryerContext).QueryContext(queryCtx, "SELECT id FROM range(3)", nil) + return queryErr + }) + require.NoError(t, err) + defer driverRows.Close() //nolint:errcheck + + cancel() + + // Pass the already-cancelled queryCtx (not context.Background()) so the test + // exercises the detached-iterator path: result paging AND CloudFetch + // downloads must survive cancellation of the ctx handed to GetArrowBatches. + batches, err := driverRows.(dbsqlrows.Rows).GetArrowBatches(queryCtx) + require.NoError(t, err) + defer batches.Close() + + var rowCount int64 + for { + record, nextErr := batches.Next() + if nextErr == io.EOF { + break + } + require.NoError(t, nextErr) + rowCount += record.NumRows() + record.Release() + } + + require.Equal(t, int64(3), rowCount) +} + func TestContextTimeoutExample(t *testing.T) { _ = logger.SetLogLevel("debug") diff --git a/internal/rows/rows.go b/internal/rows/rows.go index bd0c2605..c2fa530f 100644 --- a/internal/rows/rows.go +++ b/internal/rows/rows.go @@ -13,6 +13,7 @@ import ( dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" dbsqlclient "github.com/databricks/databricks-sql-go/internal/client" + context2 "github.com/databricks/databricks-sql-go/internal/compat/context" "github.com/databricks/databricks-sql-go/internal/config" dbsqlerr_int "github.com/databricks/databricks-sql-go/internal/errors" "github.com/databricks/databricks-sql-go/internal/rows/arrowbased" @@ -57,7 +58,16 @@ type rows struct { logger_ *dbsqllog.DBSQLLogger + // ctx is the context used for all server-side result RPCs (FetchResults, + // GetResultSetMetadata, CloseOperation) and CloudFetch downloads. It is + // detached from the caller's QueryContext cancellation so that a deadline + // gating statement submission does not truncate result streaming, while + // preserving context values used for auth/logging. It remains abortable via + // Close() through resultsCancel. ctx context.Context + // resultsCancel aborts in-flight result RPCs/downloads when Close() is + // called, so the detached ctx never leaves an operation uncancellable. + resultsCancel context.CancelFunc // Telemetry tracking // telemetryUpdate is called after each chunk is fetched with: @@ -134,6 +144,15 @@ func NewRows( logger.Debug().Msgf("databricks: creating Rows, pageSize: %d, location: %v", pageSize, location) + // QueryContext may use a short deadline to gate statement submission and + // status polling (see ES-1934053 / #295 / #371). Result handles can outlive + // that phase, especially for paginated CloudFetch streams, so detach + // server-side result RPCs from the caller's cancellation while preserving + // context values used for auth/logging. The detached context is still wired + // to a cancel func invoked from Close(), so the result handle remains + // abortable (no uncancellable in-flight FetchResults or CloudFetch download). + resultsCtx, resultsCancel := context.WithCancel(context2.WithoutCancel(ctx)) + r := &rows{ client: client, opHandle: opHandle, @@ -142,7 +161,8 @@ func NewRows( location: location, config: config, logger_: logger, - ctx: ctx, + ctx: resultsCtx, + resultsCancel: resultsCancel, chunkCount: 0, bytesDownloaded: 0, } @@ -201,7 +221,7 @@ func NewRows( // the operations. closedOnServer := directResults != nil && directResults.CloseOperation != nil r.ResultPageIterator = rowscanner.NewResultPageIterator( - ctx, + resultsCtx, d, pageSize, opHandle, @@ -244,6 +264,12 @@ func (r *rows) Close() error { return nil } + // Release the detached results context after the close RPC runs, aborting + // any in-flight FetchResults/CloudFetch downloads still referencing it. + if r.resultsCancel != nil { + defer r.resultsCancel() + } + if r.RowScanner != nil { // make sure the row scanner frees up any resources r.RowScanner.Close() @@ -635,27 +661,34 @@ func (r *rows) logger() *dbsqllog.DBSQLLogger { } func (r *rows) GetArrowBatches(ctx context.Context) (dbsqlrows.ArrowBatchIterator, error) { - // update context with correlationId and connectionId which will be used in logging and errors - ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId) + // Result fetching must outlive the caller's QueryContext deadline: both the + // inter-page FetchResults RPCs (via r.ResultPageIterator) AND the CloudFetch + // S3 downloads created from the iterator context. Build the iterator from the + // detached results context (abortable via Close) rather than the caller ctx, + // so passing a deadline-bound ctx here cannot truncate the stream. Driver + // values for logging/auth are already carried by r.ctx; re-apply the ids + // defensively. See ES-1934053 / #371. + iterCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(r.ctx, r.connId), r.correlationId) // If a row scanner exists we use it to create the iterator, that way the iterator includes // data returned as direct results if r.RowScanner != nil { - return r.RowScanner.GetArrowBatches(ctx, *r.config, r.ResultPageIterator) + return r.RowScanner.GetArrowBatches(iterCtx, *r.config, r.ResultPageIterator) } - return arrowbased.NewArrowRecordIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil + return arrowbased.NewArrowRecordIterator(iterCtx, r.ResultPageIterator, nil, nil, *r.config), nil } func (r *rows) GetArrowIPCStreams(ctx context.Context) (dbsqlrows.ArrowIPCStreamIterator, error) { - // update context with correlationId and connectionId which will be used in logging and errors - ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(ctx, r.connId), r.correlationId) + // See GetArrowBatches: result fetching is detached from the caller ctx so a + // submit-gating deadline cannot truncate streaming; it stays abortable via Close. + iterCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(r.ctx, r.connId), r.correlationId) // If a row scanner exists we use it to create the iterator, that way the iterator includes // data returned as direct results if r.RowScanner != nil { - return r.RowScanner.GetArrowIPCStreams(ctx, *r.config, r.ResultPageIterator) + return r.RowScanner.GetArrowIPCStreams(iterCtx, *r.config, r.ResultPageIterator) } - return arrowbased.NewArrowIPCStreamIterator(ctx, r.ResultPageIterator, nil, nil, *r.config), nil + return arrowbased.NewArrowIPCStreamIterator(iterCtx, r.ResultPageIterator, nil, nil, *r.config), nil } diff --git a/internal/rows/rows_test.go b/internal/rows/rows_test.go index 8947bf05..e04f9ef5 100644 --- a/internal/rows/rows_test.go +++ b/internal/rows/rows_test.go @@ -1564,6 +1564,143 @@ func TestFetchResultPage_PropagatesGetNextPageError(t *testing.T) { assert.ErrorContains(t, actualErr, errorMsg) } +func TestNewRows_DetachesResultRPCContextFromQueryContextCancellation(t *testing.T) { + t.Parallel() + + baseCtx := driverctx.NewContextWithConnId(context.Background(), "connId") + baseCtx = driverctx.NewContextWithCorrelationId(baseCtx, "corrId") + queryCtx, cancel := context.WithCancel(baseCtx) + cancel() + + assertResultCtx := func(ctx context.Context) { + assert.NoError(t, ctx.Err(), "result RPC context should not inherit query cancellation") + assert.Equal(t, "connId", driverctx.ConnIdFromContext(ctx)) + assert.Equal(t, "corrId", driverctx.CorrelationIdFromContext(ctx)) + } + + metaCalled := false + metaFn := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) { + metaCalled = true + assertResultCtx(ctx) + return &cli_service.TGetResultSetMetadataResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + Schema: &cli_service.TTableSchema{ + Columns: []*cli_service.TColumnDesc{ + {ColumnName: "flag", Position: 0, TypeDesc: &cli_service.TTypeDesc{ + Types: []*cli_service.TTypeEntry{{ + PrimitiveEntry: &cli_service.TPrimitiveTypeEntry{Type: cli_service.TTypeId_BOOLEAN_TYPE}, + }}, + }}, + }, + }, + }, nil + } + + noMoreRows := false + fetchCalled := false + fetchFn := func(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) { + fetchCalled = true + assertResultCtx(ctx) + return &cli_service.TFetchResultsResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + HasMoreRows: &noMoreRows, + Results: &cli_service.TRowSet{ + StartRowOffset: 0, + Columns: []*cli_service.TColumn{ + {BoolVal: &cli_service.TBoolColumn{Values: []bool{true}}}, + }, + }, + }, nil + } + + closeCalled := false + closeFn := func(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) { + closeCalled = true + assertResultCtx(ctx) + return &cli_service.TCloseOperationResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + }, nil + } + + testClient := &client.TestClient{ + FnFetchResults: fetchFn, + FnGetResultSetMetadata: metaFn, + FnCloseOperation: closeFn, + } + opHandle := &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{GUID: []byte("operation-id")}, + } + cfg := config.WithDefaults() + + dr, dbErr := NewRows(queryCtx, opHandle, testClient, cfg, nil, nil) + assert.Nil(t, dbErr) + + dest := make([]driver.Value, 1) + assert.NoError(t, dr.Next(dest)) + assert.Equal(t, true, dest[0]) + assert.True(t, fetchCalled, "FetchResults should use the detached result context") + assert.True(t, metaCalled, "GetResultSetMetadata should use the detached result context") + assert.True(t, closeCalled, "CloseOperation should use the detached result context") +} + +// TestNewRows_CloseAbortsDetachedResultContext verifies the detachment is not +// total: the result context survives the caller's QueryContext cancellation +// (so streaming is not truncated) but is still cancelled by Close(), so an +// in-flight FetchResults/CloudFetch download can never be left uncancellable. +func TestNewRows_CloseAbortsDetachedResultContext(t *testing.T) { + t.Parallel() + + queryCtx, cancel := context.WithCancel(context.Background()) + + var capturedCtx context.Context + noMoreRows := false + fetchFn := func(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) { + capturedCtx = ctx + return &cli_service.TFetchResultsResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + HasMoreRows: &noMoreRows, + Results: &cli_service.TRowSet{ + StartRowOffset: 0, + Columns: []*cli_service.TColumn{{BoolVal: &cli_service.TBoolColumn{Values: []bool{true}}}}, + }, + }, nil + } + metaFn := func(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) { + return &cli_service.TGetResultSetMetadataResp{ + Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}, + Schema: &cli_service.TTableSchema{Columns: []*cli_service.TColumnDesc{ + {ColumnName: "flag", Position: 0, TypeDesc: &cli_service.TTypeDesc{Types: []*cli_service.TTypeEntry{{ + PrimitiveEntry: &cli_service.TPrimitiveTypeEntry{Type: cli_service.TTypeId_BOOLEAN_TYPE}, + }}}}, + }}, + }, nil + } + closeFn := func(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) { + // The close RPC itself must still run with a live (un-cancelled) context. + assert.NoError(t, ctx.Err(), "CloseOperation must run before the result context is cancelled") + return &cli_service.TCloseOperationResp{Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS}}, nil + } + + testClient := &client.TestClient{FnFetchResults: fetchFn, FnGetResultSetMetadata: metaFn, FnCloseOperation: closeFn} + opHandle := &cli_service.TOperationHandle{OperationId: &cli_service.THandleIdentifier{GUID: []byte("operation-id")}} + + dr, dbErr := NewRows(queryCtx, opHandle, testClient, config.WithDefaults(), nil, nil) + assert.Nil(t, dbErr) + + dest := make([]driver.Value, 1) + assert.NoError(t, dr.Next(dest)) + assert.NotNil(t, capturedCtx) + + // Caller cancels the QueryContext: result context must remain alive. + cancel() + assert.NoError(t, capturedCtx.Err(), "result context must survive QueryContext cancellation") + assert.NotNil(t, capturedCtx.Done(), "result context must be abortable (non-nil Done)") + + // Close() must cancel the detached result context so nothing is left uncancellable. + assert.NoError(t, dr.Close()) + assert.ErrorIs(t, capturedCtx.Err(), context.Canceled, "Close() should cancel the detached result context") +} + // TestRows_CloseCallback_ReceivesChunkCount verifies that when rows.Close() is called, // the closeCallback receives the correct chunkCount reflecting the number of result pages // that were fetched during iteration.