diff --git a/connector.go b/connector.go
index 467adddc..2427095d 100644
--- a/connector.go
+++ b/connector.go
@@ -409,6 +409,18 @@ func WithCloudFetch(useCloudFetch bool) ConnOption {
 	}
 }
 
+// WithUseArrowNativeDecimal toggles native Arrow decimal support by setting
+// config.Config.UseArrowNativeDecimal. Default is false. When enabled, DECIMAL
+// columns retrieved over Arrow batches are decoded as native Arrow decimal128
+// values and returned as a lossless decimal string (e.g. "123.45"), preserving
+// full precision and scale. When disabled (the default), decimals are returned
+// as strings via the non-native path, so the observable Go type is the same either way.
+func WithUseArrowNativeDecimal(enable bool) ConnOption {
+	return func(c *config.Config) {
+		c.UseArrowNativeDecimal = enable
+	}
+}
+
 // WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10.
 func WithMaxDownloadThreads(numThreads int) ConnOption {
 	return func(c *config.Config) {
diff --git a/doc.go b/doc.go
index 5a3cf57c..a0ce466c 100644
--- a/doc.go
+++ b/doc.go
@@ -363,7 +363,7 @@ DATE --> time.Time
 
 TIMESTAMP --> time.Time
 
-DECIMAL(p,s) --> sql.RawBytes
+DECIMAL(p,s) --> string
 
 BINARY --> sql.RawBytes
 
diff --git a/internal/rows/arrowbased/arrowRows.go b/internal/rows/arrowbased/arrowRows.go
index 857ca368..f2c04f41 100644
--- a/internal/rows/arrowbased/arrowRows.go
+++ b/internal/rows/arrowbased/arrowRows.go
@@ -218,10 +218,11 @@ func (ars *arrowRowScanner) ScanRow(
 			col := ars.colInfo[i]
 			dbType := col.dbType
 
-			if (dbType == cli_service.TTypeId_DECIMAL_TYPE && ars.UseArrowNativeDecimal) ||
-				(isIntervalType(dbType) && ars.UseArrowNativeIntervalTypes) {
-				// not yet fully supported
-				ars.Error().Msgf(errArrowRowsUnsupportedNativeType(dbType.String()))
+			// Decimal types are supported natively (returned as a lossless string
+			// via decimal128Container). Only interval types remain unsupported.
+			if isIntervalType(dbType) && ars.UseArrowNativeIntervalTypes {
+				// interval types are not yet fully supported
+				ars.Error().Msg(errArrowRowsUnsupportedNativeType(dbType.String()))
 				return dbsqlerrint.NewDriverError(ars.ctx, errArrowRowsUnsupportedNativeType(dbType.String()), nil)
 			}
 
diff --git a/internal/rows/arrowbased/arrowRows_test.go b/internal/rows/arrowbased/arrowRows_test.go
index 705bfec7..6d38eb44 100644
--- a/internal/rows/arrowbased/arrowRows_test.go
+++ b/internal/rows/arrowbased/arrowRows_test.go
@@ -887,7 +887,10 @@ func TestArrowRowScanner(t *testing.T) {
 
 			err := ars.ScanRow(dest, 0)
 
-			if i < 3 {
+			// Columns are ordered: 0=array, 1=map, 2=struct, 3=decimal,
+			// 4=interval_ym, 5=interval_dt. Complex types (0-2) and decimal (3)
+			// are supported natively; only the interval types (4-5) still error.
+			if i < 4 {
 				assert.Nil(t, err)
 			} else {
 				assert.NotNil(t, err)
@@ -1274,7 +1277,7 @@ func TestArrowRowScanner(t *testing.T) {
 			"[[1,2,3],[4,5,6],null]",
 			"[{\"key1\":1,\"key2\":2},{\"key3\":3,\"key4\":4},null]",
 			"[{\"Field1\":77,\"Field2\":\"2020-12-31 00:00:00 +0000 UTC\"},{\"Field1\":13,\"Field2\":\"-2020-12-31 00:00:00 +0000 UTC\"},{\"Field1\":null,\"Field2\":null}]",
-			"[5.15,123.45,null]",
+			"[\"5.15\",\"123.45\",null]",
 			"[\"2020-12-31 00:00:00 +0000 UTC\",\"-2020-12-31 00:00:00 +0000 UTC\",null]",
 		}
 
diff --git a/internal/rows/arrowbased/columnValues.go b/internal/rows/arrowbased/columnValues.go
index 47105095..d3bfbd2e 100644
--- a/internal/rows/arrowbased/columnValues.go
+++ b/internal/rows/arrowbased/columnValues.go
@@ -521,9 +521,61 @@ type decimal128Container struct {
 var _ columnValues = (*decimal128Container)(nil)
 
+// Value returns the decimal at row i as a base-10 string carrying exactly
+// tvc.scale fractional digits (for positive scales), or nil when the slot is
+// null. The error result is always nil.
 func (tvc *decimal128Container) Value(i int) (any, error) {
+	if tvc.decimalArray.IsNull(i) {
+		return nil, nil
+	}
 	dv := tvc.decimalArray.Value(i)
-	fv := dv.ToFloat64(tvc.scale)
-	return fv, nil
+	// Return the decimal as a lossless string. float64 cannot exactly
+	// represent high-precision/high-scale decimals, so converting here would
+	// silently lose precision for the very type users reach for to avoid that.
+	// Returning a string also matches the non-native default path and the
+	// column-based transport, so the Go type is consistent regardless of how
+	// the server chooses to send the result.
+	if tvc.scale <= 0 {
+		// No decimal point to place; keep delegating to the library.
+		return dv.ToString(tvc.scale), nil
+	}
+	// Build the string from the exact base-10 digits of the unscaled integer.
+	// NOTE(review): whether ToString is exact for every 38-digit value depends
+	// on the arrow implementation, which is not visible here; integer digits
+	// are exact by construction.
+	return formatUnscaledDecimal(dv.BigInt().String(), int(tvc.scale)), nil
 }
+
+// formatUnscaledDecimal inserts a decimal point into unscaled, the base-10
+// rendering of an integer (optionally prefixed with '-'), so that exactly
+// scale digits follow the point. Missing leading digits are filled with
+// zeros, e.g. ("-1", 3) yields "-0.001". A non-positive scale returns
+// unscaled unchanged.
+func formatUnscaledDecimal(unscaled string, scale int) string {
+	if scale <= 0 {
+		return unscaled
+	}
+	digits := unscaled
+	negative := len(digits) > 0 && digits[0] == '-'
+	if negative {
+		digits = digits[1:]
+	}
+
+	out := make([]byte, 0, len(unscaled)+scale+2)
+	if negative {
+		out = append(out, '-')
+	}
+	intLen := len(digits) - scale
+	if intLen > 0 {
+		out = append(out, digits[:intLen]...)
+		out = append(out, '.')
+		return string(append(out, digits[intLen:]...))
+	}
+	// Fewer digits than the scale: emit "0." and left-pad the fraction.
+	out = append(out, '0', '.')
+	for ; intLen < 0; intLen++ {
+		out = append(out, '0')
+	}
+	return string(append(out, digits...))
+}
 
 func (tvc *decimal128Container) IsNull(i int) bool {
diff --git a/internal/rows/arrowbased/columnValues_test.go b/internal/rows/arrowbased/columnValues_test.go
new file mode 100644
index 00000000..de3c0d72
--- /dev/null
+++ b/internal/rows/arrowbased/columnValues_test.go
@@ -0,0 +1,116 @@
+package arrowbased
+
+import (
+	"testing"
+
+	"github.com/apache/arrow/go/v12/arrow"
+	"github.com/apache/arrow/go/v12/arrow/array"
+	"github.com/apache/arrow/go/v12/arrow/decimal128"
+	"github.com/apache/arrow/go/v12/arrow/memory"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestDecimal128ContainerValue verifies that the native Arrow decimal path
+// decodes DECIMAL columns losslessly as strings, preserving full precision and
+// scale, and that nulls are surfaced as nil. This is the value-level coverage
+// for the path exercised when UseArrowNativeDecimal is enabled (the boundary
+// test in TestArrowRowScanner only proves ScanRow no longer blocks decimals).
+func TestDecimal128ContainerValue(t *testing.T) {
+	const precision, scale int32 = 38, 18
+
+	// A 38-digit, scale-18 value that float64 (≈15-17 significant digits)
+	// cannot represent exactly — the regression guard for lossy conversion.
+	const highPrecision = "12345678901234567890.123456789012345678"
+
+	// Note: Value always renders to the column's full declared scale
+	// (18 here), so values are normalized — e.g. "3.30" -> "3.300000000000000000".
+	// This is lossless and deterministic, unlike a float64 conversion.
+	cases := []struct {
+		name     string
+		input    string // empty string => null
+		expected any
+	}{
+		{name: "high precision preserved", input: highPrecision, expected: highPrecision},
+		{name: "simple value normalized to scale", input: "3.30", expected: "3.300000000000000000"},
+		{name: "negative value", input: "-0.000000000000000001", expected: "-0.000000000000000001"},
+		{name: "zero", input: "0.000000000000000000", expected: "0.000000000000000000"},
+		{name: "null", input: "", expected: nil},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			mem := memory.NewGoAllocator()
+			dt := &arrow.Decimal128Type{Precision: precision, Scale: scale}
+			b := array.NewDecimal128Builder(mem, dt)
+			defer b.Release()
+
+			if tc.input == "" {
+				b.AppendNull()
+			} else {
+				num, err := decimal128.FromString(tc.input, precision, scale)
+				require.NoError(t, err)
+				b.Append(num)
+			}
+
+			arr := b.NewDecimal128Array()
+			defer arr.Release()
+
+			container := &decimal128Container{scale: scale}
+			err := container.SetValueArray(arr.Data())
+			require.NoError(t, err)
+
+			assert.Equal(t, tc.input == "", container.IsNull(0))
+
+			got, err := container.Value(0)
+			require.NoError(t, err)
+			assert.Equal(t, tc.expected, got)
+		})
+	}
+}
+
+// TestDecimal128NestedInListPreservesPrecision guards the DEFAULT path: complex
+// types are native by default (UseArrowNativeComplexTypes=true), so a decimal
+// nested inside an ARRAY/STRUCT is decoded by decimal128Container even when the
+// top-level UseArrowNativeDecimal flag is off. Before the lossless-string fix
+// this path went through ToFloat64 and silently corrupted high-precision values
+// (and rendered SQL NULL as the number 0). This test pins the lossless behavior.
+func TestDecimal128NestedInListPreservesPrecision(t *testing.T) {
+	const precision, scale int32 = 38, 18
+	const highPrecision = "12345678901234567890.123456789012345678"
+
+	alloc := memory.NewGoAllocator()
+	decType := &arrow.Decimal128Type{Precision: precision, Scale: scale}
+	listBuilder := array.NewListBuilder(alloc, decType)
+	defer listBuilder.Release()
+	elems := listBuilder.ValueBuilder().(*array.Decimal128Builder)
+
+	// A single list holding [highPrecision, NULL, 3.30]; "" stands for NULL.
+	listBuilder.Append(true)
+	for _, literal := range []string{highPrecision, "", "3.30"} {
+		if literal != "" {
+			dec, err := decimal128.FromString(literal, precision, scale)
+			require.NoError(t, err)
+			elems.Append(dec)
+			continue
+		}
+		elems.AppendNull()
+	}
+
+	lists := listBuilder.NewListArray()
+	defer lists.Release()
+
+	container := &listValueContainer{
+		listArray:     lists,
+		listArrayType: arrow.ListOf(decType),
+		values:        &decimal128Container{scale: scale},
+	}
+	require.NoError(t, container.values.SetValueArray(lists.ListValues().Data()))
+
+	actual, err := container.Value(0)
+	require.NoError(t, err)
+
+	// The list is expected as JSON text: decimals as strings, NULL as null.
+	want := `["12345678901234567890.123456789012345678",null,"3.300000000000000000"]`
+	assert.Equal(t, want, actual)
+}