diff --git a/driver_e2e_test.go b/driver_e2e_test.go
index 918f758f..db358024 100644
--- a/driver_e2e_test.go
+++ b/driver_e2e_test.go
@@ -657,3 +657,71 @@ func getServer(state *callState) *httptest.Server {
 		},
 	})
 }
+
+// TestE2ECloudFetchExactRowCount validates that a large CloudFetch result drains
+// the EXACT number of rows requested. CloudFetch Arrow IPC files can carry padding
+// rows beyond a link's server-declared RowCount; without capping to RowCount the
+// driver over-reports (e.g. 301,407 rows for a LIMIT 300000). This is the
+// regression guard for the row-count cap. Skipped in -short mode because it
+// drains a multi-million-row result over several CloudFetch link pages.
+func TestE2ECloudFetchExactRowCount(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping large CloudFetch drain in -short mode")
+	}
+	host := os.Getenv("DATABRICKS_PECOTESTING_SERVER_HOSTNAME")
+	httpPath := os.Getenv("DATABRICKS_PECOTESTING_HTTP_PATH2")
+	token := os.Getenv("DATABRICKS_PECOTESTING_TOKEN")
+	if token == "" {
+		token = os.Getenv("DATABRICKS_PECOTESTING_TOKEN_PERSONAL")
+	}
+	if host == "" || httpPath == "" || token == "" {
+		t.Skip("set DATABRICKS_PECOTESTING_SERVER_HOSTNAME, DATABRICKS_PECOTESTING_HTTP_PATH2, and DATABRICKS_PECOTESTING_TOKEN to run")
+	}
+
+	const wantRows = 2000000
+
+	connector, err := NewConnector(
+		WithServerHostname(host),
+		WithPort(443),
+		WithHTTPPath(httpPath),
+		WithAccessToken(token),
+		WithMaxRows(500000),
+	)
+	require.NoError(t, err)
+
+	db := sql.OpenDB(connector)
+	defer db.Close() //nolint:errcheck
+
+	conn, err := db.Conn(context.Background())
+	require.NoError(t, err)
+	defer conn.Close() //nolint:errcheck
+
+	// A wide-ish row (id + 64-byte pad) over 2M rows forces a multi-page
+	// CloudFetch (URL-based) result rather than inline Arrow.
+	query := fmt.Sprintf("SELECT id, repeat('x', 64) AS pad FROM range(%d)", wantRows)
+	var driverRows driver.Rows
+	err = conn.Raw(func(d any) error {
+		var queryErr error
+		driverRows, queryErr = d.(driver.QueryerContext).QueryContext(context.Background(), query, nil)
+		return queryErr
+	})
+	require.NoError(t, err)
+	defer driverRows.Close() //nolint:errcheck
+
+	batches, err := driverRows.(dbsqlrows.Rows).GetArrowBatches(context.Background())
+	require.NoError(t, err)
+	defer batches.Close()
+
+	var rowCount int64
+	for {
+		record, nextErr := batches.Next()
+		if nextErr == io.EOF {
+			break
+		}
+		require.NoError(t, nextErr)
+		rowCount += record.NumRows()
+		record.Release()
+	}
+
+	require.Equal(t, int64(wantRows), rowCount, "CloudFetch must surface exactly the requested rows, with no Arrow padding")
+}
diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go
index 889524d3..0a8a248e 100644
--- a/internal/rows/arrowbased/batchloader.go
+++ b/internal/rows/arrowbased/batchloader.go
@@ -31,6 +31,27 @@ type IPCStreamIterator interface {
 	Close()
 }
 
+// positionedIPCStreamIterator is an optional extension of IPCStreamIterator for
+// streams that carry server-declared positioning metadata alongside each IPC
+// payload. It is implemented ONLY by the CloudFetch iterator: CloudFetch result
+// links carry an authoritative StartRowOffset and RowCount, and the Arrow IPC
+// files they point at may be padded with extra rows beyond RowCount. batchIterator
+// uses this metadata to (a) anchor each batch at its true stream offset and
+// (b) cap the decoded records to RowCount so padding rows are not surfaced as
+// real data (see limitArrowRecords).
+//
+// The inline/local Arrow path intentionally does NOT implement this: those
+// batches are returned verbatim by the server with no padding, and their
+// per-batch RowCount has historically been untrusted, so capping there would
+// risk silently dropping rows. NextWithMetadata returns expectedRows < 0 to mean
+// "row count unknown — do not cap".
+type positionedIPCStreamIterator interface {
+	// NextWithMetadata returns the next IPC payload along with its absolute
+	// stream start offset and the server-declared row count (expectedRows). An
+	// expectedRows < 0 means the count is unknown and no capping should occur.
+	NextWithMetadata() (reader io.Reader, startRowOffset int64, expectedRows int64, err error)
+}
+
 func NewCloudIPCStreamIterator(
 	ctx context.Context,
 	files []*cli_service.TSparkArrowResultLink,
@@ -174,8 +195,17 @@ type cloudIPCStreamIterator struct {
 }
 
 var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil)
+var _ positionedIPCStreamIterator = (*cloudIPCStreamIterator)(nil)
 
 func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
+	reader, _, _, err := bi.NextWithMetadata()
+	return reader, err
+}
+
+// NextWithMetadata returns the next downloaded CloudFetch IPC payload together
+// with the link's authoritative StartRowOffset and RowCount. The Arrow file may
+// contain padding rows beyond RowCount; the caller caps to RowCount.
+func (bi *cloudIPCStreamIterator) NextWithMetadata() (io.Reader, int64, int64, error) {
 	for (bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads) && (bi.pendingLinks.Len() > 0) {
 		link := bi.pendingLinks.Dequeue()
 		logger.Debug().Msgf(
@@ -204,7 +234,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
 
 	task := bi.downloadTasks.Dequeue()
 	if task == nil {
-		return nil, io.EOF
+		return nil, 0, 0, io.EOF
 	}
 
 	data, downloadMs, err := task.GetResult()
@@ -212,7 +242,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
 	// once we've got an errored out task - cancel the remaining ones
 	if err != nil {
 		bi.Close()
-		return nil, err
+		return nil, 0, 0, err
 	}
 
 	// explicitly call cancel function on successfully completed task to avoid context leak
@@ -226,7 +256,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
 		bi.onFileDownloaded(downloadMs)
 	}
 
-	return data, nil
+	return data, task.link.StartRowOffset, task.link.RowCount, nil
 }
 
 func (bi *cloudIPCStreamIterator) HasNext() bool {
@@ -558,16 +588,38 @@ func NewBatchIterator(ipcIterator IPCStreamIterator, startRowOffset int64) Batch
 }
 
 func (bi *batchIterator) Next() (SparkArrowBatch, error) {
-	reader, err := bi.ipcIterator.Next()
+	// startRowOffset is the absolute offset of this batch within the result
+	// stream. For positioned (CloudFetch) streams it comes from the server's
+	// link metadata; otherwise we track it locally by accumulating decoded rows.
+	startRowOffset := bi.startRowOffset
+	// expectedRows is the server-declared row count for this batch. A value < 0
+	// means "unknown" and disables capping (the inline/local path).
+	expectedRows := int64(-1)
+	var reader io.Reader
+	var err error
+	if positionedIterator, ok := bi.ipcIterator.(positionedIPCStreamIterator); ok {
+		reader, startRowOffset, expectedRows, err = positionedIterator.NextWithMetadata()
+	} else {
+		reader, err = bi.ipcIterator.Next()
+	}
 	if err != nil {
 		return nil, err
 	}
 
-	records, err := getArrowRecords(reader, bi.startRowOffset)
+	records, err := getArrowRecords(reader, startRowOffset)
 	if err != nil {
 		return nil, err
 	}
 
+	// Cap the decoded records to the server-declared row count, dropping the
+	// padding rows some CloudFetch Arrow files carry beyond their link's
+	// RowCount. Only cap when the count is strictly positive: expectedRows == 0
+	// with decoded rows is treated as "untrustworthy, do not cap" rather than
+	// silently dropping the whole batch (see #371 review F1).
+	if expectedRows > 0 {
+		records = limitArrowRecords(records, expectedRows)
+	}
+
 	// When using CloudFetch, cached Arrow IPC files may contain stale column
 	// names from a previous query. Replace the embedded schema with the
 	// authoritative schema from GetResultSetMetadata.
@@ -593,14 +645,72 @@ func (bi *batchIterator) Next() (SparkArrowBatch, error) {
 	}
 
 	batch := &sparkArrowBatch{
-		Delimiter:    rowscanner.NewDelimiter(bi.startRowOffset, totalRows),
+		Delimiter:    rowscanner.NewDelimiter(startRowOffset, totalRows),
 		arrowRecords: records,
 	}
 
-	bi.startRowOffset += totalRows
+	// Advance the local offset for the next non-positioned batch. Positioned
+	// streams overwrite startRowOffset from server metadata on the next call.
+	bi.startRowOffset = startRowOffset + totalRows
 	return batch, nil
 }
 
+// limitArrowRecords caps a decoded batch to expectedRows, releasing any records
+// (and the tail of a partially-kept record) that fall beyond the server-declared
+// count. It is the mechanism that strips CloudFetch Arrow padding rows.
+//
+// Contract:
+//   - Callers must only invoke this when expectedRows is trustworthy and the
+//     batch may be over-long; expectedRows < 0 is treated as "unknown" and the
+//     records are returned unchanged.
+//   - expectedRows == 0 releases every record and returns an empty slice, so
+//     callers that do not trust a zero count must not call this function.
+//   - When a record straddles the boundary it is sliced with NewSlice(0, remaining):
+//     the slice bounds are record-relative (0-based within the record), while the
+//     Delimiter's start is the ABSOLUTE stream offset of the record. Keep these two
+//     distinct — do not pass the absolute start as a slice index.
+//   - The input slice's backing array is reused for the result; entries past the
+//     returned length are cleared so released records are not left reachable.
+func limitArrowRecords(records []SparkArrowRecord, expectedRows int64) []SparkArrowRecord {
+	if expectedRows < 0 {
+		return records
+	}
+
+	remaining := expectedRows
+	limited := records[:0]
+	for _, record := range records {
+		if remaining <= 0 {
+			record.Release()
+			continue
+		}
+
+		if record.NumRows() <= remaining {
+			limited = append(limited, record)
+			remaining -= record.NumRows()
+			continue
+		}
+
+		start := record.Start()
+		sliced := record.NewSlice(0, remaining)
+		record.Release()
+		if sliced != nil {
+			limited = append(limited, &sparkArrowRecord{
+				Delimiter: rowscanner.NewDelimiter(start, sliced.NumRows()),
+				Record:    sliced,
+			})
+		}
+		remaining = 0
+	}
+
+	// limited aliases records' backing array; nil out the unused tail so the
+	// released records are not kept reachable through it.
+	for i := len(limited); i < len(records); i++ {
+		records[i] = nil
+	}
+
+	return limited
+}
+
 func (bi *batchIterator) HasNext() bool {
 	return bi.ipcIterator.HasNext()
 }
diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go
index d66ded36..3e3e4831 100644
--- a/internal/rows/arrowbased/batchloader_test.go
+++ b/internal/rows/arrowbased/batchloader_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"runtime"
@@ -1160,3 +1161,114 @@ func countDownloadTaskGoroutines() int {
 	}
 	return strings.Count(string(buf), "cloudFetchDownloadTask).Run")
 }
+
+// fakePositionedIPCIterator is a test IPCStreamIterator that also implements
+// positionedIPCStreamIterator, so it exercises the CloudFetch row-count-capping
+// path through NewBatchIterator without real CloudFetch downloads.
+type fakePositionedIPCIterator struct {
+	data           []byte
+	startRowOffset int64
+	expectedRows   int64
+	consumed       bool
+}
+
+var _ IPCStreamIterator = (*fakePositionedIPCIterator)(nil)
+var _ positionedIPCStreamIterator = (*fakePositionedIPCIterator)(nil)
+
+func (f *fakePositionedIPCIterator) Next() (io.Reader, error) {
+	r, _, _, err := f.NextWithMetadata()
+	return r, err
+}
+func (f *fakePositionedIPCIterator) NextWithMetadata() (io.Reader, int64, int64, error) {
+	if f.consumed {
+		return nil, 0, 0, io.EOF
+	}
+	f.consumed = true
+	return bytes.NewReader(f.data), f.startRowOffset, f.expectedRows, nil
+}
+func (f *fakePositionedIPCIterator) HasNext() bool { return !f.consumed }
+func (f *fakePositionedIPCIterator) Close()        {}
+
+// fakePlainIPCIterator implements only IPCStreamIterator (the inline/local
+// shape) so the cap must never apply to it.
+type fakePlainIPCIterator struct {
+	data     []byte
+	consumed bool
+}
+
+var _ IPCStreamIterator = (*fakePlainIPCIterator)(nil)
+
+func (f *fakePlainIPCIterator) Next() (io.Reader, error) {
+	if f.consumed {
+		return nil, io.EOF
+	}
+	f.consumed = true
+	return bytes.NewReader(f.data), nil
+}
+func (f *fakePlainIPCIterator) HasNext() bool { return !f.consumed }
+func (f *fakePlainIPCIterator) Close()        {}
+
+// TestBatchIterator_RowCountCap covers the batchIterator -> limitArrowRecords
+// integration (#371 review F1/F6). generateMockArrowBytes writes the 3-row
+// record twice, so each stream decodes to 6 rows.
+func TestBatchIterator_RowCountCap(t *testing.T) {
+	const decoded = 6
+	const startOffset int64 = 100
+
+	t.Run("positioned: caps padding rows down to RowCount", func(t *testing.T) {
+		it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 4}
+		bi := NewBatchIterator(it, startOffset)
+		batch, err := bi.Next()
+		if !assert.NoError(t, err) {
+			return
+		}
+		defer batch.Close()
+		assert.Equal(t, int64(4), batch.Count(), "batch should be capped to RowCount")
+		assert.Equal(t, startOffset, batch.Start(), "batch must anchor at the server offset")
+	})
+
+	t.Run("positioned: exact boundary keeps all rows", func(t *testing.T) {
+		it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: decoded}
+		bi := NewBatchIterator(it, startOffset)
+		batch, err := bi.Next()
+		if !assert.NoError(t, err) {
+			return
+		}
+		defer batch.Close()
+		assert.Equal(t, int64(decoded), batch.Count())
+	})
+
+	t.Run("positioned: RowCount larger than decoded keeps all rows", func(t *testing.T) {
+		it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 100}
+		bi := NewBatchIterator(it, startOffset)
+		batch, err := bi.Next()
+		if !assert.NoError(t, err) {
+			return
+		}
+		defer batch.Close()
+		assert.Equal(t, int64(decoded), batch.Count())
+	})
+
+	t.Run("positioned: RowCount==0 is NOT trusted, keeps all rows (F1)", func(t *testing.T) {
+		it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 0}
+		bi := NewBatchIterator(it, startOffset)
+		batch, err := bi.Next()
+		if !assert.NoError(t, err) {
+			return
+		}
+		defer batch.Close()
+		assert.Equal(t, int64(decoded), batch.Count(), "RowCount==0 must not silently drop the batch")
+	})
+
+	t.Run("plain/inline iterator is never capped", func(t *testing.T) {
+		it := &fakePlainIPCIterator{data: generateMockArrowBytes(generateArrowRecord())}
+		bi := NewBatchIterator(it, startOffset)
+		batch, err := bi.Next()
+		if !assert.NoError(t, err) {
+			return
+		}
+		defer batch.Close()
+		assert.Equal(t, int64(decoded), batch.Count(), "inline path must return all decoded rows")
+		assert.Equal(t, startOffset, batch.Start())
+	})
+}