From a17229f16362826698bfbc3663882bcefbe79013 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Wed, 3 Jun 2026 00:23:59 +0530 Subject: [PATCH] Cap CloudFetch Arrow batches to server-declared RowCount CloudFetch Arrow IPC files can contain padding rows beyond the RowCount declared on each result link, which the driver surfaced as extra rows (e.g. 301,407 rows returned for SELECT ... LIMIT 300000). Cap each decoded CloudFetch batch to its link's RowCount and anchor batch offsets to the link's StartRowOffset. This matches the official JDBC driver, whose ArrowResultChunkIterator stops iterating once rowsReadByIterator >= numRows (the server-declared TSparkArrowResultLink.RowCount), silently ignoring any padding rows in the Arrow file. The cap is scoped to the CloudFetch path only, via a new positionedIPCStreamIterator interface implemented solely by cloudIPCStreamIterator. The inline/local Arrow path is intentionally left uncapped: those batches are returned verbatim with no padding and their per-batch RowCount has historically been untrusted, so capping there could silently drop rows. A RowCount <= 0 is treated as "unknown" and never drops rows. Adds a unit test driving batchIterator -> limitArrowRecords (cap-down, exact boundary, over-count, RowCount==0 safety, inline-never-capped) and an env-gated CloudFetch e2e asserting an exact 2,000,000-row drain. Co-authored-by: Isaac Signed-off-by: Madhavendra Rathore --- driver_e2e_test.go | 68 +++++++++++ internal/rows/arrowbased/batchloader.go | 114 +++++++++++++++++-- internal/rows/arrowbased/batchloader_test.go | 102 +++++++++++++++++ 3 files changed, 277 insertions(+), 7 deletions(-) diff --git a/driver_e2e_test.go b/driver_e2e_test.go index 918f758f..db358024 100644 --- a/driver_e2e_test.go +++ b/driver_e2e_test.go @@ -657,3 +657,71 @@ func getServer(state *callState) *httptest.Server { }, }) } + +// TestE2ECloudFetchExactRowCount validates that a large CloudFetch result drains +// the EXACT number of rows requested. CloudFetch Arrow IPC files can carry padding +// rows beyond a link's server-declared RowCount; without capping to RowCount the +// driver over-reports (e.g. 301,407 rows for a LIMIT 300000). This is the +// regression guard for the row-count cap. Skipped in -short mode because it +// drains a multi-million-row result over several CloudFetch link pages. +func TestE2ECloudFetchExactRowCount(t *testing.T) { + if testing.Short() { + t.Skip("skipping large CloudFetch drain in -short mode") + } + host := os.Getenv("DATABRICKS_PECOTESTING_SERVER_HOSTNAME") + httpPath := os.Getenv("DATABRICKS_PECOTESTING_HTTP_PATH2") + token := os.Getenv("DATABRICKS_PECOTESTING_TOKEN") + if token == "" { + token = os.Getenv("DATABRICKS_PECOTESTING_TOKEN_PERSONAL") + } + if host == "" || httpPath == "" || token == "" { + t.Skip("set DATABRICKS_PECOTESTING_SERVER_HOSTNAME, DATABRICKS_PECOTESTING_HTTP_PATH2, and DATABRICKS_PECOTESTING_TOKEN to run") + } + + const wantRows = 2000000 + + connector, err := NewConnector( + WithServerHostname(host), + WithPort(443), + WithHTTPPath(httpPath), + WithAccessToken(token), + WithMaxRows(500000), + ) + require.NoError(t, err) + + db := sql.OpenDB(connector) + defer db.Close() //nolint:errcheck + + conn, err := db.Conn(context.Background()) + require.NoError(t, err) + defer conn.Close() //nolint:errcheck + + // A wide-ish row (id + 64-byte pad) over 2M rows forces a multi-page + // CloudFetch (URL-based) result rather than inline Arrow. + query := fmt.Sprintf("SELECT id, repeat('x', 64) AS pad FROM range(%d)", wantRows) + var driverRows driver.Rows + err = conn.Raw(func(d any) error { + var queryErr error + driverRows, queryErr = d.(driver.QueryerContext).QueryContext(context.Background(), query, nil) + return queryErr + }) + require.NoError(t, err) + defer driverRows.Close() //nolint:errcheck + + batches, err := driverRows.(dbsqlrows.Rows).GetArrowBatches(context.Background()) + require.NoError(t, err) + defer batches.Close() + + var rowCount int64 + for { + record, nextErr := batches.Next() + if nextErr == io.EOF { + break + } + require.NoError(t, nextErr) + rowCount += record.NumRows() + record.Release() + } + + require.Equal(t, int64(wantRows), rowCount, "CloudFetch must surface exactly the requested rows, with no Arrow padding") +} diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index 889524d3..0a8a248e 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -31,6 +31,27 @@ type IPCStreamIterator interface { Close() } +// positionedIPCStreamIterator is an optional extension of IPCStreamIterator for +// streams that carry server-declared positioning metadata alongside each IPC +// payload. It is implemented ONLY by the CloudFetch iterator: CloudFetch result +// links carry an authoritative StartRowOffset and RowCount, and the Arrow IPC +// files they point at may be padded with extra rows beyond RowCount. batchIterator +// uses this metadata to (a) anchor each batch at its true stream offset and +// (b) cap the decoded records to RowCount so padding rows are not surfaced as +// real data (see limitArrowRecords). +// +// The inline/local Arrow path intentionally does NOT implement this: those +// batches are returned verbatim by the server with no padding, and their +// per-batch RowCount has historically been untrusted, so capping there would +// risk silently dropping rows. NextWithMetadata returns expectedRows < 0 to mean +// "row count unknown — do not cap". +type positionedIPCStreamIterator interface { + // NextWithMetadata returns the next IPC payload along with its absolute + // stream start offset and the server-declared row count (expectedRows). An + // expectedRows < 0 means the count is unknown and no capping should occur. + NextWithMetadata() (reader io.Reader, startRowOffset int64, expectedRows int64, err error) +} + func NewCloudIPCStreamIterator( ctx context.Context, files []*cli_service.TSparkArrowResultLink, @@ -174,8 +195,17 @@ type cloudIPCStreamIterator struct { } var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil) +var _ positionedIPCStreamIterator = (*cloudIPCStreamIterator)(nil) func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { + reader, _, _, err := bi.NextWithMetadata() + return reader, err +} + +// NextWithMetadata returns the next downloaded CloudFetch IPC payload together +// with the link's authoritative StartRowOffset and RowCount. The Arrow file may +// contain padding rows beyond RowCount; the caller caps to RowCount. +func (bi *cloudIPCStreamIterator) NextWithMetadata() (io.Reader, int64, int64, error) { for (bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads) && (bi.pendingLinks.Len() > 0) { link := bi.pendingLinks.Dequeue() logger.Debug().Msgf( @@ -204,7 +234,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { task := bi.downloadTasks.Dequeue() if task == nil { - return nil, io.EOF + return nil, 0, 0, io.EOF } data, downloadMs, err := task.GetResult() @@ -212,7 +242,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { // once we've got an errored out task - cancel the remaining ones if err != nil { bi.Close() - return nil, err + return nil, 0, 0, err } // explicitly call cancel function on successfully completed task to avoid context leak @@ -226,7 +256,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { bi.onFileDownloaded(downloadMs) } - return data, nil + return data, task.link.StartRowOffset, task.link.RowCount, nil } func (bi *cloudIPCStreamIterator) HasNext() bool { @@ -558,16 +588,38 @@ func NewBatchIterator(ipcIterator IPCStreamIterator, startRowOffset int64) Batch } func (bi *batchIterator) Next() (SparkArrowBatch, error) { - reader, err := bi.ipcIterator.Next() + // startRowOffset is the absolute offset of this batch within the result + // stream. For positioned (CloudFetch) streams it comes from the server's + // link metadata; otherwise we track it locally by accumulating decoded rows. + startRowOffset := bi.startRowOffset + // expectedRows is the server-declared row count for this batch. A value < 0 + // means "unknown" and disables capping (the inline/local path). + expectedRows := int64(-1) + var reader io.Reader + var err error + if positionedIterator, ok := bi.ipcIterator.(positionedIPCStreamIterator); ok { + reader, startRowOffset, expectedRows, err = positionedIterator.NextWithMetadata() + } else { + reader, err = bi.ipcIterator.Next() + } if err != nil { return nil, err } - records, err := getArrowRecords(reader, bi.startRowOffset) + records, err := getArrowRecords(reader, startRowOffset) if err != nil { return nil, err } + // Cap the decoded records to the server-declared row count, dropping the + // padding rows some CloudFetch Arrow files carry beyond their link's + // RowCount. Only cap when the count is strictly positive: expectedRows == 0 + // with decoded rows is treated as "untrustworthy, do not cap" rather than + // silently dropping the whole batch (see #371 review F1). + if expectedRows > 0 { + records = limitArrowRecords(records, expectedRows) + } + // When using CloudFetch, cached Arrow IPC files may contain stale column // names from a previous query. Replace the embedded schema with the // authoritative schema from GetResultSetMetadata. @@ -593,14 +645,62 @@ func (bi *batchIterator) Next() (SparkArrowBatch, error) { } batch := &sparkArrowBatch{ - Delimiter: rowscanner.NewDelimiter(bi.startRowOffset, totalRows), + Delimiter: rowscanner.NewDelimiter(startRowOffset, totalRows), arrowRecords: records, } - bi.startRowOffset += totalRows + // Advance the local offset for the next non-positioned batch. Positioned + // streams overwrite startRowOffset from server metadata on the next call. + bi.startRowOffset = startRowOffset + totalRows return batch, nil } +// limitArrowRecords caps a decoded batch to expectedRows, releasing any records +// (and the tail of a partially-kept record) that fall beyond the server-declared +// count. It is the mechanism that strips CloudFetch Arrow padding rows. +// +// Contract: +// - Callers must only invoke this when expectedRows is trustworthy and the +// batch may be over-long; expectedRows < 0 is treated as "unknown" and the +// records are returned unchanged. +// - When a record straddles the boundary it is sliced with NewSlice(0, remaining): +// the slice bounds are record-relative (0-based within the record), while the +// Delimiter's start is the ABSOLUTE stream offset of the record. Keep these two +// distinct — do not pass the absolute start as a slice index. +func limitArrowRecords(records []SparkArrowRecord, expectedRows int64) []SparkArrowRecord { + if expectedRows < 0 { + return records + } + + remaining := expectedRows + limited := records[:0] + for _, record := range records { + if remaining <= 0 { + record.Release() + continue + } + + if record.NumRows() <= remaining { + limited = append(limited, record) + remaining -= record.NumRows() + continue + } + + start := record.Start() + sliced := record.NewSlice(0, remaining) + record.Release() + if sliced != nil { + limited = append(limited, &sparkArrowRecord{ + Delimiter: rowscanner.NewDelimiter(start, sliced.NumRows()), + Record: sliced, + }) + } + remaining = 0 + } + + return limited +} + func (bi *batchIterator) HasNext() bool { return bi.ipcIterator.HasNext() } diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index d66ded36..3e3e4831 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "io" "net/http" "net/http/httptest" "runtime" @@ -1160,3 +1161,104 @@ func countDownloadTaskGoroutines() int { } return strings.Count(string(buf), "cloudFetchDownloadTask).Run") } + +// fakePositionedIPCIterator is a test IPCStreamIterator that also implements +// positionedIPCStreamIterator, so it exercises the CloudFetch row-count-capping +// path through NewBatchIterator without real CloudFetch downloads. +type fakePositionedIPCIterator struct { + data []byte + startRowOffset int64 + expectedRows int64 + consumed bool +} + +var _ IPCStreamIterator = (*fakePositionedIPCIterator)(nil) +var _ positionedIPCStreamIterator = (*fakePositionedIPCIterator)(nil) + +func (f *fakePositionedIPCIterator) Next() (io.Reader, error) { + r, _, _, err := f.NextWithMetadata() + return r, err +} +func (f *fakePositionedIPCIterator) NextWithMetadata() (io.Reader, int64, int64, error) { + if f.consumed { + return nil, 0, 0, io.EOF + } + f.consumed = true + return bytes.NewReader(f.data), f.startRowOffset, f.expectedRows, nil +} +func (f *fakePositionedIPCIterator) HasNext() bool { return !f.consumed } +func (f *fakePositionedIPCIterator) Close() {} + +// fakePlainIPCIterator implements only IPCStreamIterator (the inline/local +// shape) so the cap must never apply to it. +type fakePlainIPCIterator struct { + data []byte + consumed bool +} + +var _ IPCStreamIterator = (*fakePlainIPCIterator)(nil) + +func (f *fakePlainIPCIterator) Next() (io.Reader, error) { + if f.consumed { + return nil, io.EOF + } + f.consumed = true + return bytes.NewReader(f.data), nil +} +func (f *fakePlainIPCIterator) HasNext() bool { return !f.consumed } +func (f *fakePlainIPCIterator) Close() {} + +// TestBatchIterator_RowCountCap covers the batchIterator -> limitArrowRecords +// integration (#371 review F1/F6). generateMockArrowBytes writes the 3-row +// record twice, so each stream decodes to 6 rows. +func TestBatchIterator_RowCountCap(t *testing.T) { + const decoded = 6 + const startOffset int64 = 100 + + t.Run("positioned: caps padding rows down to RowCount", func(t *testing.T) { + it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 4} + bi := NewBatchIterator(it, startOffset) + batch, err := bi.Next() + assert.NoError(t, err) + defer batch.Close() + assert.Equal(t, int64(4), batch.Count(), "batch should be capped to RowCount") + assert.Equal(t, startOffset, batch.Start(), "batch must anchor at the server offset") + }) + + t.Run("positioned: exact boundary keeps all rows", func(t *testing.T) { + it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: decoded} + bi := NewBatchIterator(it, startOffset) + batch, err := bi.Next() + assert.NoError(t, err) + defer batch.Close() + assert.Equal(t, int64(decoded), batch.Count()) + }) + + t.Run("positioned: RowCount larger than decoded keeps all rows", func(t *testing.T) { + it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 100} + bi := NewBatchIterator(it, startOffset) + batch, err := bi.Next() + assert.NoError(t, err) + defer batch.Close() + assert.Equal(t, int64(decoded), batch.Count()) + }) + + t.Run("positioned: RowCount==0 is NOT trusted, keeps all rows (F1)", func(t *testing.T) { + it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 0} + bi := NewBatchIterator(it, startOffset) + batch, err := bi.Next() + assert.NoError(t, err) + defer batch.Close() + assert.Equal(t, int64(decoded), batch.Count(), "RowCount==0 must not silently drop the batch") + }) + + t.Run("plain/inline iterator is never capped", func(t *testing.T) { + it := &fakePlainIPCIterator{data: generateMockArrowBytes(generateArrowRecord())} + bi := NewBatchIterator(it, startOffset) + batch, err := bi.Next() + assert.NoError(t, err) + defer batch.Close() + assert.Equal(t, int64(decoded), batch.Count(), "inline path must return all decoded rows") + assert.Equal(t, startOffset, batch.Start()) + }) +}