Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions driver_e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,3 +657,71 @@ func getServer(state *callState) *httptest.Server {
},
})
}

// TestE2ECloudFetchExactRowCount validates that a large CloudFetch result drains
// the EXACT number of rows requested. CloudFetch Arrow IPC files can carry padding
// rows beyond a link's server-declared RowCount; without capping to RowCount the
// driver over-reports (e.g. 301,407 rows for a LIMIT 300000). This is the
// regression guard for the row-count cap.
//
// The test is skipped in -short mode because it drains a multi-million-row
// result over several CloudFetch link pages, and it is skipped when the
// connection environment variables are not set. The access token is read from
// DATABRICKS_PECOTESTING_TOKEN, falling back to
// DATABRICKS_PECOTESTING_TOKEN_PERSONAL.
func TestE2ECloudFetchExactRowCount(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping large CloudFetch drain in -short mode")
	}
	host := os.Getenv("DATABRICKS_PECOTESTING_SERVER_HOSTNAME")
	httpPath := os.Getenv("DATABRICKS_PECOTESTING_HTTP_PATH2")
	token := os.Getenv("DATABRICKS_PECOTESTING_TOKEN")
	if token == "" {
		token = os.Getenv("DATABRICKS_PECOTESTING_TOKEN_PERSONAL")
	}
	if host == "" || httpPath == "" || token == "" {
		t.Skip("set DATABRICKS_PECOTESTING_SERVER_HOSTNAME, DATABRICKS_PECOTESTING_HTTP_PATH2, and DATABRICKS_PECOTESTING_TOKEN to run")
	}

	const wantRows = 2000000
	ctx := context.Background()

	connector, err := NewConnector(
		WithServerHostname(host),
		WithPort(443),
		WithHTTPPath(httpPath),
		WithAccessToken(token),
		WithMaxRows(500000),
	)
	require.NoError(t, err)

	db := sql.OpenDB(connector)
	defer db.Close() //nolint:errcheck

	conn, err := db.Conn(ctx)
	require.NoError(t, err)
	defer conn.Close() //nolint:errcheck

	// A wide-ish row (id + 64-byte pad) over 2M rows forces a multi-page
	// CloudFetch (URL-based) result rather than inline Arrow.
	query := fmt.Sprintf("SELECT id, repeat('x', 64) AS pad FROM range(%d)", wantRows)
	var driverRows driver.Rows
	err = conn.Raw(func(d any) error {
		// Use a checked assertion: a panic inside Raw would abort the whole test
		// binary instead of reporting a normal test failure.
		queryer, ok := d.(driver.QueryerContext)
		if !ok {
			return fmt.Errorf("driver connection %T does not implement driver.QueryerContext", d)
		}
		var queryErr error
		driverRows, queryErr = queryer.QueryContext(ctx, query, nil)
		return queryErr
	})
	require.NoError(t, err)
	// Guard the deferred Close below against a nil interface value.
	require.NotNil(t, driverRows)
	defer driverRows.Close() //nolint:errcheck

	arrowRows, ok := driverRows.(dbsqlrows.Rows)
	require.True(t, ok, "driver rows %T must implement dbsqlrows.Rows", driverRows)

	batches, err := arrowRows.GetArrowBatches(ctx)
	require.NoError(t, err)
	defer batches.Close()

	// Drain every Arrow batch, counting rows and releasing each record as soon
	// as it has been counted.
	var rowCount int64
	for {
		record, nextErr := batches.Next()
		if nextErr == io.EOF {
			break
		}
		require.NoError(t, nextErr)
		rowCount += record.NumRows()
		record.Release()
	}

	require.Equal(t, int64(wantRows), rowCount, "CloudFetch must surface exactly the requested rows, with no Arrow padding")
}
114 changes: 107 additions & 7 deletions internal/rows/arrowbased/batchloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ type IPCStreamIterator interface {
Close()
}

// positionedIPCStreamIterator is an optional extension of IPCStreamIterator for
// streams that carry server-declared positioning metadata alongside each IPC
// payload. Outside of test doubles it is implemented ONLY by the CloudFetch
// iterator (cloudIPCStreamIterator): CloudFetch result links carry an
// authoritative StartRowOffset and RowCount, and the Arrow IPC files they point
// at may be padded with extra rows beyond RowCount. batchIterator detects this
// interface with a type assertion and uses the metadata to (a) anchor each
// batch at its true stream offset and (b) cap the decoded records to RowCount
// so padding rows are not surfaced as real data (see limitArrowRecords).
//
// The inline/local Arrow path intentionally does NOT implement this: those
// batches are returned verbatim by the server with no padding, and their
// per-batch RowCount has historically been untrusted, so capping there would
// risk silently dropping rows. NextWithMetadata returns expectedRows < 0 to mean
// "row count unknown — do not cap". Note that batchIterator.Next only caps when
// expectedRows is strictly positive, so a reported count of 0 also leaves the
// decoded records untouched.
type positionedIPCStreamIterator interface {
// NextWithMetadata returns the next IPC payload along with its absolute
// stream start offset and the server-declared row count (expectedRows). An
// expectedRows < 0 means the count is unknown and no capping should occur.
// At end of stream the implementations in this package return io.EOF as err.
NextWithMetadata() (reader io.Reader, startRowOffset int64, expectedRows int64, err error)
}

func NewCloudIPCStreamIterator(
ctx context.Context,
files []*cli_service.TSparkArrowResultLink,
Expand Down Expand Up @@ -174,8 +195,17 @@ type cloudIPCStreamIterator struct {
}

// Compile-time checks that cloudIPCStreamIterator satisfies both the basic
// stream iterator contract and the positioned (metadata-carrying) extension.
var (
	_ IPCStreamIterator           = (*cloudIPCStreamIterator)(nil)
	_ positionedIPCStreamIterator = (*cloudIPCStreamIterator)(nil)
)

// Next returns the next CloudFetch IPC payload. It is a thin adapter over
// NextWithMetadata that discards the start offset and row count, so callers
// that only know IPCStreamIterator keep working unchanged.
func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
	payload, _, _, err := bi.NextWithMetadata()
	return payload, err
}

// NextWithMetadata returns the next downloaded CloudFetch IPC payload together
// with the link's authoritative StartRowOffset and RowCount. The Arrow file may
// contain padding rows beyond RowCount; the caller caps to RowCount.
func (bi *cloudIPCStreamIterator) NextWithMetadata() (io.Reader, int64, int64, error) {
for (bi.downloadTasks.Len() < bi.cfg.MaxDownloadThreads) && (bi.pendingLinks.Len() > 0) {
link := bi.pendingLinks.Dequeue()
logger.Debug().Msgf(
Expand Down Expand Up @@ -204,15 +234,15 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {

task := bi.downloadTasks.Dequeue()
if task == nil {
return nil, io.EOF
return nil, 0, 0, io.EOF
}

data, downloadMs, err := task.GetResult()

// once we've got an errored out task - cancel the remaining ones
if err != nil {
bi.Close()
return nil, err
return nil, 0, 0, err
}

// explicitly call cancel function on successfully completed task to avoid context leak
Expand All @@ -226,7 +256,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) {
bi.onFileDownloaded(downloadMs)
}

return data, nil
return data, task.link.StartRowOffset, task.link.RowCount, nil
}

func (bi *cloudIPCStreamIterator) HasNext() bool {
Expand Down Expand Up @@ -558,16 +588,38 @@ func NewBatchIterator(ipcIterator IPCStreamIterator, startRowOffset int64) Batch
}

func (bi *batchIterator) Next() (SparkArrowBatch, error) {
reader, err := bi.ipcIterator.Next()
// startRowOffset is the absolute offset of this batch within the result
// stream. For positioned (CloudFetch) streams it comes from the server's
// link metadata; otherwise we track it locally by accumulating decoded rows.
startRowOffset := bi.startRowOffset
// expectedRows is the server-declared row count for this batch. A value < 0
// means "unknown" and disables capping (the inline/local path).
expectedRows := int64(-1)
var reader io.Reader
var err error
if positionedIterator, ok := bi.ipcIterator.(positionedIPCStreamIterator); ok {
reader, startRowOffset, expectedRows, err = positionedIterator.NextWithMetadata()
} else {
reader, err = bi.ipcIterator.Next()
}
if err != nil {
return nil, err
}

records, err := getArrowRecords(reader, bi.startRowOffset)
records, err := getArrowRecords(reader, startRowOffset)
if err != nil {
return nil, err
}

// Cap the decoded records to the server-declared row count, dropping the
// padding rows some CloudFetch Arrow files carry beyond their link's
// RowCount. Only cap when the count is strictly positive: expectedRows == 0
// with decoded rows is treated as "untrustworthy, do not cap" rather than
// silently dropping the whole batch (see #371 review F1).
if expectedRows > 0 {
records = limitArrowRecords(records, expectedRows)
}

// When using CloudFetch, cached Arrow IPC files may contain stale column
// names from a previous query. Replace the embedded schema with the
// authoritative schema from GetResultSetMetadata.
Expand All @@ -593,14 +645,62 @@ func (bi *batchIterator) Next() (SparkArrowBatch, error) {
}

batch := &sparkArrowBatch{
Delimiter: rowscanner.NewDelimiter(bi.startRowOffset, totalRows),
Delimiter: rowscanner.NewDelimiter(startRowOffset, totalRows),
arrowRecords: records,
}

bi.startRowOffset += totalRows
// Advance the local offset for the next non-positioned batch. Positioned
// streams overwrite startRowOffset from server metadata on the next call.
bi.startRowOffset = startRowOffset + totalRows
return batch, nil
}

// limitArrowRecords trims a decoded batch so that it holds at most expectedRows
// rows. It is the mechanism that strips CloudFetch Arrow padding rows.
//
// Behaviour:
//   - expectedRows < 0 means "unknown": records is returned untouched.
//   - Records are kept whole, in order, while they fit in the remaining budget.
//   - The first record that would overshoot the budget is replaced by a slice
//     of its leading rows and the original record is released. If NewSlice
//     yields nil the record is dropped entirely.
//   - Every record encountered once the budget is exhausted is released and
//     dropped.
//
// The returned slice reuses the backing array of records, so callers must use
// the return value and stop using the argument.
//
// The bounds passed to NewSlice are record-relative (0-based within the
// record), while the Delimiter of the sliced record keeps the original
// record's Start() — the record's ABSOLUTE offset in the stream. Keep these two
// distinct: never pass the absolute start as a slice index.
func limitArrowRecords(records []SparkArrowRecord, expectedRows int64) []SparkArrowRecord {
	if expectedRows < 0 {
		return records
	}

	// In-place filter: kept never grows past the index currently being read.
	kept := records[:0]
	budget := expectedRows
	for i := range records {
		rec := records[i]

		switch {
		case budget <= 0:
			// Budget already spent: this record is entirely surplus.
			rec.Release()

		case rec.NumRows() <= budget:
			// Fits completely; keep it as-is.
			kept = append(kept, rec)
			budget -= rec.NumRows()

		default:
			// Straddles the boundary: keep only the leading budget rows.
			absStart := rec.Start()
			head := rec.NewSlice(0, budget)
			rec.Release()
			if head != nil {
				kept = append(kept, &sparkArrowRecord{
					Delimiter: rowscanner.NewDelimiter(absStart, head.NumRows()),
					Record:    head,
				})
			}
			budget = 0
		}
	}

	return kept
}

// HasNext reports whether another batch may be available. It delegates
// directly to the underlying IPC stream iterator's HasNext.
func (bi *batchIterator) HasNext() bool {
return bi.ipcIterator.HasNext()
}
Expand Down
102 changes: 102 additions & 0 deletions internal/rows/arrowbased/batchloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"runtime"
Expand Down Expand Up @@ -1160,3 +1161,104 @@ func countDownloadTaskGoroutines() int {
}
return strings.Count(string(buf), "cloudFetchDownloadTask).Run")
}

// fakePositionedIPCIterator is a single-shot test double for the CloudFetch
// stream shape: it satisfies both IPCStreamIterator and
// positionedIPCStreamIterator, which lets tests drive the row-count-capping
// path of NewBatchIterator without performing real CloudFetch downloads.
type fakePositionedIPCIterator struct {
	data           []byte // payload bytes handed out on the first call
	startRowOffset int64  // start offset reported with the payload
	expectedRows   int64  // row count reported with the payload
	consumed       bool   // set once the payload has been handed out
}

var (
	_ IPCStreamIterator           = (*fakePositionedIPCIterator)(nil)
	_ positionedIPCStreamIterator = (*fakePositionedIPCIterator)(nil)
)

// Next returns the payload without its metadata; see NextWithMetadata.
func (f *fakePositionedIPCIterator) Next() (io.Reader, error) {
	payload, _, _, err := f.NextWithMetadata()
	return payload, err
}

// NextWithMetadata hands out the configured payload and metadata exactly once
// and reports io.EOF on every later call.
func (f *fakePositionedIPCIterator) NextWithMetadata() (io.Reader, int64, int64, error) {
	if !f.consumed {
		f.consumed = true
		return bytes.NewReader(f.data), f.startRowOffset, f.expectedRows, nil
	}
	return nil, 0, 0, io.EOF
}

// HasNext is true until the single payload has been handed out.
func (f *fakePositionedIPCIterator) HasNext() bool {
	return !f.consumed
}

// Close is a no-op; the fake holds no resources.
func (f *fakePositionedIPCIterator) Close() {}

// fakePlainIPCIterator is a single-shot test double with the inline/local
// stream shape: it implements only IPCStreamIterator and deliberately not
// positionedIPCStreamIterator, so the row-count cap must never apply to it.
type fakePlainIPCIterator struct {
	data     []byte // payload bytes handed out on the first call
	consumed bool   // set once the payload has been handed out
}

var _ IPCStreamIterator = (*fakePlainIPCIterator)(nil)

// Next hands out the configured payload exactly once and reports io.EOF on
// every later call.
func (f *fakePlainIPCIterator) Next() (io.Reader, error) {
	if !f.consumed {
		f.consumed = true
		return bytes.NewReader(f.data), nil
	}
	return nil, io.EOF
}

// HasNext is true until the single payload has been handed out.
func (f *fakePlainIPCIterator) HasNext() bool {
	return !f.consumed
}

// Close is a no-op; the fake holds no resources.
func (f *fakePlainIPCIterator) Close() {}

// TestBatchIterator_RowCountCap covers the batchIterator -> limitArrowRecords
// integration (#371 review F1/F6). generateMockArrowBytes writes the 3-row
// record twice, so each stream decodes to 6 rows.
func TestBatchIterator_RowCountCap(t *testing.T) {
const decoded = 6
const startOffset int64 = 100

t.Run("positioned: caps padding rows down to RowCount", func(t *testing.T) {
it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 4}
bi := NewBatchIterator(it, startOffset)
batch, err := bi.Next()
assert.NoError(t, err)
defer batch.Close()
assert.Equal(t, int64(4), batch.Count(), "batch should be capped to RowCount")
assert.Equal(t, startOffset, batch.Start(), "batch must anchor at the server offset")
})

t.Run("positioned: exact boundary keeps all rows", func(t *testing.T) {
it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: decoded}
bi := NewBatchIterator(it, startOffset)
batch, err := bi.Next()
assert.NoError(t, err)
defer batch.Close()
assert.Equal(t, int64(decoded), batch.Count())
})

t.Run("positioned: RowCount larger than decoded keeps all rows", func(t *testing.T) {
it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 100}
bi := NewBatchIterator(it, startOffset)
batch, err := bi.Next()
assert.NoError(t, err)
defer batch.Close()
assert.Equal(t, int64(decoded), batch.Count())
})

t.Run("positioned: RowCount==0 is NOT trusted, keeps all rows (F1)", func(t *testing.T) {
it := &fakePositionedIPCIterator{data: generateMockArrowBytes(generateArrowRecord()), startRowOffset: startOffset, expectedRows: 0}
bi := NewBatchIterator(it, startOffset)
batch, err := bi.Next()
assert.NoError(t, err)
defer batch.Close()
assert.Equal(t, int64(decoded), batch.Count(), "RowCount==0 must not silently drop the batch")
})

t.Run("plain/inline iterator is never capped", func(t *testing.T) {
it := &fakePlainIPCIterator{data: generateMockArrowBytes(generateArrowRecord())}
bi := NewBatchIterator(it, startOffset)
batch, err := bi.Next()
assert.NoError(t, err)
defer batch.Close()
assert.Equal(t, int64(decoded), batch.Count(), "inline path must return all decoded rows")
assert.Equal(t, startOffset, batch.Start())
})
}
Loading