From e2b5e5cf7326ec7e10c457f5a7b7b42399c01f3b Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 5 May 2026 21:24:15 +0800 Subject: [PATCH 01/24] feat: enhance schema alignment and recursion handling - Added `align_plan_to_schema` and `SchemaAlignExec` for improved schema alignment in execution plans. - Maintained strict behavior in `project_plan_to_schema` for projection-only cases. - Updated adapter to handle nullability narrowing while preserving SQL behavior. - Modified `RecursiveQueryExec` to preserve static/declared schema and aligned recursive term at plan construction. - Removed nullability-widening schema synthesis for cleaner execution. - Restored `0 AS` level in SQL logic test file `cte.slt`. --- datafusion/physical-plan/src/common.rs | 280 +++++++++++++++--- .../physical-plan/src/recursive_query.rs | 59 ++-- datafusion/sqllogictest/test_files/cte.slt | 2 +- 3 files changed, 269 insertions(+), 72 deletions(-) diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 0dafcf6bd3390..000d41d03ebb5 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -24,15 +24,21 @@ use std::sync::Arc; use super::SendableRecordBatchStream; use crate::expressions::{CastExpr, Column}; use crate::projection::{ProjectionExec, ProjectionExpr}; -use crate::stream::RecordBatchReceiverStream; -use crate::{ColumnStatistics, ExecutionPlan, Statistics}; +use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; +use crate::{ + ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + PlanProperties, Statistics, +}; use arrow::array::Array; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{Result, plan_err}; +use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::MemoryReservation; +use 
datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; use futures::{StreamExt, TryStreamExt}; @@ -94,57 +100,70 @@ fn build_file_list_recurse( /// /// This helper is intended for operators that combine independently planned children but /// expose a single declared output schema. It returns `input` unchanged when schemas already -/// match exactly. Otherwise, it validates that projection can safely produce the expected -/// schema, then wraps `input` in a [`ProjectionExec`] that keeps columns in their existing -/// positional order and aliases them to `expected_schema`'s field names. +/// match exactly. Otherwise, it validates positional compatibility and uses a plan-time +/// adapter whose advertised and emitted schema is exactly `expected_schema`. /// -/// [`ProjectionExec`] can rename fields. When the expected field is nullable and the input -/// field is not, this helper also widens nullability with a same-type [`CastExpr`]. It rejects -/// differences that projection cannot safely normalize exactly, such as data type, metadata, -/// schema metadata, and nullability narrowing. -pub fn project_plan_to_schema( +/// Prefer this helper over rebinding batches inside a parent operator's stream. The alignment +/// is visible in the physical plan, while batch schema rebinding remains contained in the +/// adapter as the implementation detail required to uphold the plan-level schema contract. +/// +/// This helper can align field names and nullability to the declared schema. It rejects +/// differences that would change values or silently lose schema information, such as column +/// count, data type, field metadata, or schema metadata mismatches. 
+pub fn align_plan_to_schema( input: Arc, expected_schema: &SchemaRef, ) -> Result> { - let input_schema = input.schema(); - if input_schema.as_ref() == expected_schema.as_ref() { + validate_schema_alignment(&input.schema(), expected_schema, "align")?; + + if input.schema().as_ref() == expected_schema.as_ref() { return Ok(input); } - if input_schema.fields().len() != expected_schema.fields().len() { - return plan_err!( - "Cannot project plan to expected schema: expected {} column(s), got {}", - expected_schema.fields().len(), - input_schema.fields().len() - ); + if let Ok(projected) = project_plan_to_schema(Arc::clone(&input), expected_schema) { + if projected.schema().as_ref() == expected_schema.as_ref() { + return Ok(projected); + } } - if input_schema.metadata() != expected_schema.metadata() { - return plan_err!( - "Cannot project plan to expected schema: schema metadata differ" - ); + Ok(Arc::new(SchemaAlignExec::try_new( + input, + Arc::clone(expected_schema), + )?)) +} + +/// Project `input` to `expected_schema` when [`ProjectionExec`] can produce that exact schema. +/// +/// This is a narrower helper than [`align_plan_to_schema`]. It is useful when a positional +/// projection/alias is sufficient. It rejects requests where projection cannot advertise the +/// exact expected schema, such as nullability narrowing. 
+pub fn project_plan_to_schema( + input: Arc, + expected_schema: &SchemaRef, +) -> Result> { + let input_schema = input.schema(); + validate_schema_alignment(&input_schema, expected_schema, "project")?; + + if input_schema.as_ref() == expected_schema.as_ref() { + return Ok(input); } - if let Some((i, input_field, expected_field, mismatch)) = input_schema + if let Some((i, input_field, expected_field)) = input_schema .fields() .iter() .zip(expected_schema.fields().iter()) .enumerate() .find_map(|(i, (input_field, expected_field))| { - if input_field.data_type() != expected_field.data_type() { - Some((i, input_field, expected_field, "data type")) - } else if input_field.is_nullable() && !expected_field.is_nullable() { - Some((i, input_field, expected_field, "nullability")) - } else if input_field.metadata() != expected_field.metadata() { - Some((i, input_field, expected_field, "metadata")) - } else { - None - } + (input_field.is_nullable() && !expected_field.is_nullable()).then_some(( + i, + input_field, + expected_field, + )) }) { return plan_err!( "Cannot project plan column {i} ('{}') to expected output field '{}': \ - field {mismatch} differs (input field: {:?}, expected field: {:?})", + field nullability differs (input field: {:?}, expected field: {:?})", input_field.name(), expected_field.name(), input_field, @@ -180,6 +199,182 @@ pub fn project_plan_to_schema( Ok(Arc::new(projection)) } +fn validate_schema_alignment( + input_schema: &SchemaRef, + expected_schema: &SchemaRef, + operation: &str, +) -> Result<()> { + if input_schema.fields().len() != expected_schema.fields().len() { + return plan_err!( + "Cannot {operation} plan to expected schema: expected {} column(s), got {}", + expected_schema.fields().len(), + input_schema.fields().len() + ); + } + + if input_schema.metadata() != expected_schema.metadata() { + return plan_err!( + "Cannot {operation} plan to expected schema: schema metadata differ" + ); + } + + if let Some((i, input_field, expected_field, 
mismatch)) = input_schema + .fields() + .iter() + .zip(expected_schema.fields().iter()) + .enumerate() + .find_map(|(i, (input_field, expected_field))| { + if input_field.data_type() != expected_field.data_type() { + Some((i, input_field, expected_field, "data type")) + } else if input_field.metadata() != expected_field.metadata() { + Some((i, input_field, expected_field, "metadata")) + } else { + None + } + }) + { + return plan_err!( + "Cannot {operation} plan column {i} ('{}') to expected output field '{}': \ + field {mismatch} differs (input field: {:?}, expected field: {:?})", + input_field.name(), + expected_field.name(), + input_field, + expected_field + ); + } + + Ok(()) +} + +/// Plan-time schema adapter for positional schema alignment. +/// +/// [`ProjectionExec`] cannot express every schema-only alignment. In particular, a column +/// expression remains nullable when its input field is nullable, so projection cannot advertise +/// a non-null expected field. This adapter is for cases where the operator-level contract has +/// already established that columns are positionally compatible and the child plan must expose +/// the declared schema exactly. +#[derive(Debug, Clone)] +pub struct SchemaAlignExec { + input: Arc, + schema: SchemaRef, + cache: Arc, +} + +impl SchemaAlignExec { + /// Create a new schema alignment adapter. 
+ pub fn try_new(input: Arc, schema: SchemaRef) -> Result { + validate_schema_alignment(&input.schema(), &schema, "align")?; + + let input_properties = input.properties(); + let partitioning = match &input_properties.partitioning { + Partitioning::RoundRobinBatch(partitions) => { + Partitioning::RoundRobinBatch(*partitions) + } + Partitioning::Hash(_, partitions) + | Partitioning::UnknownPartitioning(partitions) => { + Partitioning::UnknownPartitioning(*partitions) + } + }; + let properties = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + partitioning, + input_properties.emission_type, + input_properties.boundedness, + ) + .with_evaluation_type(input_properties.evaluation_type) + .with_scheduling_type(input_properties.scheduling_type); + + Ok(Self { + input, + schema, + cache: Arc::new(properties), + }) + } + + /// Input plan being aligned. + pub fn input(&self) -> &Arc { + &self.input + } +} + +impl DisplayAs for SchemaAlignExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "SchemaAlignExec") + } + DisplayFormatType::TreeRender => write!(f, ""), + } + } +} + +impl ExecutionPlan for SchemaAlignExec { + fn name(&self) -> &'static str { + "SchemaAlignExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let [input] = children.try_into().map_err(|children: Vec<_>| { + datafusion_common::DataFusionError::Internal(format!( + "SchemaAlignExec expected 1 child, got {}", + children.len() + )) + })?; + Ok(Arc::new(Self::try_new(input, 
Arc::clone(&self.schema))?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let schema = Arc::clone(&self.schema); + let stream = self.input.execute(partition, context)?.map({ + let schema = Arc::clone(&schema); + move |batch| { + let batch = batch?; + if batch.schema().as_ref() == schema.as_ref() { + Ok(batch) + } else { + RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec()) + .map_err(Into::into) + } + } + }); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) + } + + fn partition_statistics(&self, partition: Option) -> Result> { + self.input.partition_statistics(partition) + } +} + /// If running in a tokio context spawns the execution of `stream` to a separate task /// allowing it to execute in parallel with an intermediate buffer of size `buffer` pub fn spawn_buffered( @@ -523,6 +718,25 @@ mod tests { assert!(err.to_string().contains("field nullability differs")); } + #[test] + fn align_plan_to_schema_uses_adapter_for_nullability_narrowing() -> Result<()> { + let input = empty_exec(vec![Field::new("a", DataType::Int32, true)]); + let expected_schema = Arc::new(Schema::new(vec![Field::new( + "renamed", + DataType::Int32, + false, + )])); + + let result = align_plan_to_schema(Arc::clone(&input), &expected_schema)?; + + let aligned = result + .downcast_ref::() + .expect("nullability narrowing should use SchemaAlignExec"); + assert!(Arc::ptr_eq(aligned.input(), &input)); + assert_eq!(aligned.schema(), expected_schema); + Ok(()) + } + #[test] fn project_plan_to_schema_errors_on_field_metadata_mismatch() { let input = diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index c160f9a0dc763..337e5d3d300e4 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -24,7 +24,7 @@ use std::task::{Context, Poll}; use super::work_table::{ReservedBatches, WorkTable}; use 
crate::aggregates::group_values::{GroupValues, new_group_values}; use crate::aggregates::order::GroupOrdering; -use crate::common::project_plan_to_schema; +use crate::common::align_plan_to_schema; use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states}; use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, @@ -35,7 +35,7 @@ use crate::{ }; use arrow::array::{BooleanArray, BooleanBuilder}; use arrow::compute::filter_record_batch; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -90,13 +90,13 @@ impl RecursiveQueryExec { ) -> Result { // Each recursive query needs its own work table let work_table = Arc::new(WorkTable::new(name.clone())); - // Use the same work table for both the WorkTableExec and the recursive term - let output_schema = - recursive_output_schema(&static_term.schema(), &recursive_term.schema()); - let static_term = project_plan_to_schema(static_term, &output_schema)?; + // Use the static term as the declared recursive CTE output schema. The + // recursive term is planned independently, so align it at plan construction + // time instead of patching batches in RecursiveQueryStream. 
+ let output_schema = static_term.schema(); let recursive_term = assign_work_table(recursive_term, &work_table)?; - let recursive_term = project_plan_to_schema(recursive_term, &output_schema)?; - let cache = Self::compute_properties(output_schema); + let recursive_term = align_plan_to_schema(recursive_term, &output_schema)?; + let cache = Self::compute_properties(Arc::clone(&output_schema)); Ok(RecursiveQueryExec { name, static_term, @@ -370,30 +370,6 @@ impl RecursiveQueryStream { } } -fn recursive_output_schema( - static_schema: &SchemaRef, - recursive_schema: &SchemaRef, -) -> SchemaRef { - let fields = static_schema - .fields() - .iter() - .zip(recursive_schema.fields()) - .map(|(static_field, recursive_field)| { - Field::new( - static_field.name(), - static_field.data_type().clone(), - static_field.is_nullable() || recursive_field.is_nullable(), - ) - .with_metadata(static_field.metadata().clone()) - }) - .collect::>(); - - Arc::new(Schema::new_with_metadata( - fields, - static_schema.metadata().clone(), - )) -} - fn assign_work_table( plan: Arc, work_table: &Arc, @@ -520,6 +496,7 @@ fn new_groups_mask( #[cfg(test)] mod tests { use super::*; + use crate::common::SchemaAlignExec; use crate::empty::EmptyExec; use crate::projection::ProjectionExec; @@ -554,21 +531,27 @@ mod tests { } #[test] - fn recursive_query_exec_reconciles_nullability() -> Result<()> { + fn recursive_query_exec_preserves_static_nullability_contract() -> Result<()> { let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); let recursive_term = empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]); let exec = RecursiveQueryExec::try_new( "numbers".to_string(), - static_term, - recursive_term, + Arc::clone(&static_term), + Arc::clone(&recursive_term), false, )?; - assert!(exec.schema().field(0).is_nullable()); - assert!(exec.static_term().schema().field(0).is_nullable()); - assert!(exec.recursive_term().schema().field(0).is_nullable()); + 
assert_eq!(exec.schema(), static_term.schema()); + assert_eq!(exec.static_term().schema(), static_term.schema()); + assert_eq!(exec.recursive_term().schema(), static_term.schema()); + assert!(!exec.schema().field(0).is_nullable()); + let aligned = exec + .recursive_term() + .downcast_ref::() + .expect("nullable recursive term should be aligned with SchemaAlignExec"); + assert!(Arc::ptr_eq(aligned.input(), &recursive_term)); Ok(()) } } diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index d13e0d4f085e9..bb5a18d53d82d 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -699,7 +699,7 @@ WITH RECURSIVE region_sales AS ( SELECT s.salesperson_id AS salesperson_id, SUM(s.sale_amount) AS amount, - SUM(0) as level + 0 as level FROM sales s GROUP BY From f191f7cf486305b0321dc8afd5240f3a52ba2801 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 6 May 2026 11:31:16 +0800 Subject: [PATCH 02/24] feat(datafusion): add direct tests for align_plan_to_schema and document behavior - Added direct tests for align_plan_to_schema: - Verified exact schema returns the same plan. - Ensured rename-only uses ProjectionExec. - Confirmed nullability narrowing uses SchemaAlignExec. - Tested count/type/field metadata/schema metadata errors. - Documented conservative property behavior in the adapter path. --- datafusion/physical-plan/src/common.rs | 89 ++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 000d41d03ebb5..77991349e0d3b 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -110,6 +110,10 @@ fn build_file_list_recurse( /// This helper can align field names and nullability to the declared schema. 
It rejects /// differences that would change values or silently lose schema information, such as column /// count, data type, field metadata, or schema metadata mismatches. +/// +/// When an adapter is required, it conservatively derives fresh equivalence properties from +/// `expected_schema` and drops child hash partitioning because field names/nullability may have +/// changed while the underlying partitioning expressions still refer to the child schema. pub fn align_plan_to_schema( input: Arc, expected_schema: &SchemaRef, @@ -718,6 +722,37 @@ mod tests { assert!(err.to_string().contains("field nullability differs")); } + #[test] + fn align_plan_to_schema_returns_input_when_schema_matches() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int32, + false, + )])); + let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let result = align_plan_to_schema(Arc::clone(&input), &schema)?; + + assert!(Arc::ptr_eq(&input, &result)); + Ok(()) + } + + #[test] + fn align_plan_to_schema_uses_projection_for_rename_only() -> Result<()> { + let input = empty_exec(vec![Field::new("recursive_a", DataType::Int32, false)]); + let expected_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + + let result = align_plan_to_schema(Arc::clone(&input), &expected_schema)?; + + let projection = result + .downcast_ref::() + .expect("rename-only alignment should use ProjectionExec"); + assert!(Arc::ptr_eq(projection.input(), &input)); + assert_eq!(projection.schema(), expected_schema); + Ok(()) + } + #[test] fn align_plan_to_schema_uses_adapter_for_nullability_narrowing() -> Result<()> { let input = empty_exec(vec![Field::new("a", DataType::Int32, true)]); @@ -737,6 +772,60 @@ mod tests { Ok(()) } + #[test] + fn align_plan_to_schema_errors_on_column_count_mismatch() { + let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("a", 
DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("expected 2 column")); + } + + #[test] + fn align_plan_to_schema_errors_on_type_mismatch() { + let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); + let expected_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("field data type differs")); + } + + #[test] + fn align_plan_to_schema_errors_on_field_metadata_mismatch() { + let input = + empty_exec(vec![Field::new("a", DataType::Int32, false).with_metadata( + HashMap::from([("source".to_string(), "input".to_string())]), + )]); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("renamed", DataType::Int32, false).with_metadata(HashMap::from([ + ("source".to_string(), "expected".to_string()), + ])), + ])); + + let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("field metadata differs")); + } + + #[test] + fn align_plan_to_schema_errors_on_schema_metadata_mismatch() { + let input_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("a", DataType::Int32, false)], + HashMap::from([("source".to_string(), "input".to_string())]), + )); + let input: Arc = Arc::new(EmptyExec::new(input_schema)); + let expected_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("renamed", DataType::Int32, false)], + HashMap::from([("source".to_string(), "expected".to_string())]), + )); + + let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("schema metadata differ")); + } + #[test] fn project_plan_to_schema_errors_on_field_metadata_mismatch() { let input = From 1b39805849488c1f60d905e54159a9cf7c8d85df Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 6 May 2026 
11:36:19 +0800 Subject: [PATCH 03/24] feat: improve schema alignment checks in execution plans - Refactored `align_plan_to_schema` function to store input schema in a variable, reducing redundant calls. - Updated validation and comparison logic for better clarity and performance. - Simplified partitioning handling in `SchemaAlignExec` by consolidating pattern matching. - Enhanced `DisplayAs` implementation to correctly handle `TreeRender` format. --- datafusion/physical-plan/src/common.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 77991349e0d3b..dd219e95548fe 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -118,16 +118,16 @@ pub fn align_plan_to_schema( input: Arc, expected_schema: &SchemaRef, ) -> Result> { - validate_schema_alignment(&input.schema(), expected_schema, "align")?; + let input_schema = input.schema(); + validate_schema_alignment(&input_schema, expected_schema, "align")?; - if input.schema().as_ref() == expected_schema.as_ref() { + if input_schema.as_ref() == expected_schema.as_ref() { return Ok(input); } if let Ok(projected) = project_plan_to_schema(Arc::clone(&input), expected_schema) { - if projected.schema().as_ref() == expected_schema.as_ref() { - return Ok(projected); - } + debug_assert_eq!(projected.schema().as_ref(), expected_schema.as_ref()); + return Ok(projected); } Ok(Arc::new(SchemaAlignExec::try_new( @@ -274,10 +274,7 @@ impl SchemaAlignExec { Partitioning::RoundRobinBatch(partitions) => { Partitioning::RoundRobinBatch(*partitions) } - Partitioning::Hash(_, partitions) - | Partitioning::UnknownPartitioning(partitions) => { - Partitioning::UnknownPartitioning(*partitions) - } + partitioning => Partitioning::UnknownPartitioning(partitioning.partition_count()), }; let properties = PlanProperties::new( EquivalenceProperties::new(Arc::clone(&schema)), @@ -311,7 
+308,7 @@ impl DisplayAs for SchemaAlignExec { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "SchemaAlignExec") } - DisplayFormatType::TreeRender => write!(f, ""), + DisplayFormatType::TreeRender => Ok(()), } } } From c0a606600f98f3e186c7dea43b06a55678125fb8 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 6 May 2026 11:37:26 +0800 Subject: [PATCH 04/24] feat: Improve performance and clarity in common and recursive query modules - Reuse `input_schema` in common.rs - Simplify projected return using `debug_assert_eq!` - Utilize `partition_count()` in common.rs - Modify TreeRender to return `Ok(())` - Reuse `static_schema` in tests for recursive_query.rs --- datafusion/physical-plan/src/common.rs | 4 +++- datafusion/physical-plan/src/recursive_query.rs | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index dd219e95548fe..cae578a27d5a5 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -274,7 +274,9 @@ impl SchemaAlignExec { Partitioning::RoundRobinBatch(partitions) => { Partitioning::RoundRobinBatch(*partitions) } - partitioning => Partitioning::UnknownPartitioning(partitioning.partition_count()), + partitioning => { + Partitioning::UnknownPartitioning(partitioning.partition_count()) + } }; let properties = PlanProperties::new( EquivalenceProperties::new(Arc::clone(&schema)), diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 337e5d3d300e4..313a5e86e4879 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -543,9 +543,10 @@ mod tests { false, )?; - assert_eq!(exec.schema(), static_term.schema()); - assert_eq!(exec.static_term().schema(), static_term.schema()); - assert_eq!(exec.recursive_term().schema(), static_term.schema()); + let static_schema = 
static_term.schema(); + assert_eq!(exec.schema(), static_schema); + assert_eq!(exec.static_term().schema(), static_schema); + assert_eq!(exec.recursive_term().schema(), static_schema); assert!(!exec.schema().field(0).is_nullable()); let aligned = exec .recursive_term() From 7fa72c2649d7f389adf8ac98ee8a7da978d3b99b Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 6 May 2026 11:58:00 +0800 Subject: [PATCH 05/24] feat: improve test setup and simplify validation in physical plan - Removed redundant upfront align validation in common.rs. - Added test helpers in common.rs: - single_field_schema - single_i32_exec - metadata mismatch builders - Shortened repeated test setup in common.rs. - Added recursive_exec test helper in recursive_query.rs. - Simplified RecursiveQueryExec::try_new(...) in recursive_query.rs. --- datafusion/physical-plan/src/common.rs | 130 ++++++++---------- .../physical-plan/src/recursive_query.rs | 26 ++-- 2 files changed, 68 insertions(+), 88 deletions(-) diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index cae578a27d5a5..b3ab903f02228 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -119,7 +119,6 @@ pub fn align_plan_to_schema( expected_schema: &SchemaRef, ) -> Result> { let input_schema = input.schema(); - validate_schema_alignment(&input_schema, expected_schema, "align")?; if input_schema.as_ref() == expected_schema.as_ref() { return Ok(input); @@ -507,6 +506,40 @@ mod tests { Arc::new(EmptyExec::new(Arc::new(Schema::new(fields)))) } + fn single_field_schema(name: &str, data_type: DataType, nullable: bool) -> SchemaRef { + Arc::new(Schema::new(vec![Field::new(name, data_type, nullable)])) + } + + fn single_i32_exec(name: &str, nullable: bool) -> Arc { + empty_exec(vec![Field::new(name, DataType::Int32, nullable)]) + } + + fn field_metadata_mismatch() -> (Arc, SchemaRef) { + let input = + empty_exec(vec![Field::new("a", DataType::Int32, 
false).with_metadata( + HashMap::from([("source".to_string(), "input".to_string())]), + )]); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("renamed", DataType::Int32, false).with_metadata(HashMap::from([ + ("source".to_string(), "expected".to_string()), + ])), + ])); + (input, expected_schema) + } + + fn schema_metadata_mismatch() -> (Arc, SchemaRef) { + let input_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("a", DataType::Int32, false)], + HashMap::from([("source".to_string(), "input".to_string())]), + )); + let input: Arc = Arc::new(EmptyExec::new(input_schema)); + let expected_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("renamed", DataType::Int32, false)], + HashMap::from([("source".to_string(), "expected".to_string())]), + )); + (input, expected_schema) + } + #[test] fn test_compute_record_batch_statistics_empty() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -608,11 +641,7 @@ mod tests { #[test] fn project_plan_to_schema_returns_input_when_schema_matches() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new( - "value", - DataType::Int32, - false, - )])); + let schema = single_field_schema("value", DataType::Int32, false); let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); let result = project_plan_to_schema(Arc::clone(&input), &schema)?; @@ -673,7 +702,7 @@ mod tests { #[test] fn project_plan_to_schema_errors_on_column_count_mismatch() { - let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); + let input = single_i32_exec("a", false); let expected_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -685,9 +714,8 @@ mod tests { #[test] fn project_plan_to_schema_errors_on_type_mismatch() { - let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); - let expected_schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + let input = 
single_i32_exec("a", false); + let expected_schema = single_field_schema("a", DataType::Float32, false); let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field data type differs")); @@ -695,12 +723,8 @@ mod tests { #[test] fn project_plan_to_schema_widens_nullability() -> Result<()> { - let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); - let expected_schema = Arc::new(Schema::new(vec![Field::new( - "renamed", - DataType::Int32, - true, - )])); + let input = single_i32_exec("a", false); + let expected_schema = single_field_schema("renamed", DataType::Int32, true); let result = project_plan_to_schema(input, &expected_schema)?; @@ -710,12 +734,8 @@ mod tests { #[test] fn project_plan_to_schema_errors_on_nullability_narrowing() { - let input = empty_exec(vec![Field::new("a", DataType::Int32, true)]); - let expected_schema = Arc::new(Schema::new(vec![Field::new( - "renamed", - DataType::Int32, - false, - )])); + let input = single_i32_exec("a", true); + let expected_schema = single_field_schema("renamed", DataType::Int32, false); let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field nullability differs")); @@ -723,11 +743,7 @@ mod tests { #[test] fn align_plan_to_schema_returns_input_when_schema_matches() -> Result<()> { - let schema = Arc::new(Schema::new(vec![Field::new( - "value", - DataType::Int32, - false, - )])); + let schema = single_field_schema("value", DataType::Int32, false); let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); let result = align_plan_to_schema(Arc::clone(&input), &schema)?; @@ -738,9 +754,8 @@ mod tests { #[test] fn align_plan_to_schema_uses_projection_for_rename_only() -> Result<()> { - let input = empty_exec(vec![Field::new("recursive_a", DataType::Int32, false)]); - let expected_schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let input = 
single_i32_exec("recursive_a", false); + let expected_schema = single_field_schema("a", DataType::Int32, false); let result = align_plan_to_schema(Arc::clone(&input), &expected_schema)?; @@ -754,12 +769,8 @@ mod tests { #[test] fn align_plan_to_schema_uses_adapter_for_nullability_narrowing() -> Result<()> { - let input = empty_exec(vec![Field::new("a", DataType::Int32, true)]); - let expected_schema = Arc::new(Schema::new(vec![Field::new( - "renamed", - DataType::Int32, - false, - )])); + let input = single_i32_exec("a", true); + let expected_schema = single_field_schema("renamed", DataType::Int32, false); let result = align_plan_to_schema(Arc::clone(&input), &expected_schema)?; @@ -773,7 +784,7 @@ mod tests { #[test] fn align_plan_to_schema_errors_on_column_count_mismatch() { - let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); + let input = single_i32_exec("a", false); let expected_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -785,9 +796,8 @@ mod tests { #[test] fn align_plan_to_schema_errors_on_type_mismatch() { - let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); - let expected_schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + let input = single_i32_exec("a", false); + let expected_schema = single_field_schema("a", DataType::Float32, false); let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field data type differs")); @@ -795,15 +805,7 @@ mod tests { #[test] fn align_plan_to_schema_errors_on_field_metadata_mismatch() { - let input = - empty_exec(vec![Field::new("a", DataType::Int32, false).with_metadata( - HashMap::from([("source".to_string(), "input".to_string())]), - )]); - let expected_schema = Arc::new(Schema::new(vec![ - Field::new("renamed", DataType::Int32, false).with_metadata(HashMap::from([ - ("source".to_string(), "expected".to_string()), - ])), - 
])); + let (input, expected_schema) = field_metadata_mismatch(); let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field metadata differs")); @@ -811,15 +813,7 @@ mod tests { #[test] fn align_plan_to_schema_errors_on_schema_metadata_mismatch() { - let input_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("a", DataType::Int32, false)], - HashMap::from([("source".to_string(), "input".to_string())]), - )); - let input: Arc = Arc::new(EmptyExec::new(input_schema)); - let expected_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("renamed", DataType::Int32, false)], - HashMap::from([("source".to_string(), "expected".to_string())]), - )); + let (input, expected_schema) = schema_metadata_mismatch(); let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("schema metadata differ")); @@ -827,15 +821,7 @@ mod tests { #[test] fn project_plan_to_schema_errors_on_field_metadata_mismatch() { - let input = - empty_exec(vec![Field::new("a", DataType::Int32, false).with_metadata( - HashMap::from([("source".to_string(), "input".to_string())]), - )]); - let expected_schema = Arc::new(Schema::new(vec![ - Field::new("renamed", DataType::Int32, false).with_metadata(HashMap::from([ - ("source".to_string(), "expected".to_string()), - ])), - ])); + let (input, expected_schema) = field_metadata_mismatch(); let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field metadata differs")); @@ -843,15 +829,7 @@ mod tests { #[test] fn project_plan_to_schema_errors_on_schema_metadata_mismatch() { - let input_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("a", DataType::Int32, false)], - HashMap::from([("source".to_string(), "input".to_string())]), - )); - let input: Arc = Arc::new(EmptyExec::new(input_schema)); - let expected_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("renamed", 
DataType::Int32, false)], - HashMap::from([("source".to_string(), "expected".to_string())]), - )); + let (input, expected_schema) = schema_metadata_mismatch(); let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("schema metadata differ")); diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 313a5e86e4879..5c79750332cca 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -506,18 +506,25 @@ mod tests { Arc::new(EmptyExec::new(Arc::new(Schema::new(fields)))) } + fn recursive_exec( + static_term: Arc, + recursive_term: Arc, + ) -> Result { + RecursiveQueryExec::try_new( + "numbers".to_string(), + static_term, + recursive_term, + false, + ) + } + #[test] fn recursive_query_exec_projects_recursive_term_to_reconciled_schema() -> Result<()> { let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); let recursive_term = empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, false)]); - let exec = RecursiveQueryExec::try_new( - "numbers".to_string(), - Arc::clone(&static_term), - Arc::clone(&recursive_term), - false, - )?; + let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?; assert_eq!(exec.schema(), static_term.schema()); let projection = exec @@ -536,12 +543,7 @@ mod tests { let recursive_term = empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]); - let exec = RecursiveQueryExec::try_new( - "numbers".to_string(), - Arc::clone(&static_term), - Arc::clone(&recursive_term), - false, - )?; + let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?; let static_schema = static_term.schema(); assert_eq!(exec.schema(), static_schema); From 4959978ded76b3e5407381b4ef80a2f596e584fb Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 10 May 2026 22:23:45 +0800 Subject: [PATCH 06/24] feat: 
enhance recursive query functionality with new schema handling and tests - Added `schema: DFSchemaRef` to `RecursiveQuery`. - Updated `LogicalPlan::RecursiveQuery.schema()` to return the stored schema. - Introduced `RecursiveQuery::try_new(...)` for schema derivation based on static anchor field names, qualifiers, data types, nullability, and intersected metadata. - Implemented manual `PartialOrd` for `RecursiveQuery`. - Modified `to_recursive_query` to utilize `RecursiveQuery::try_new(...)`. - Added unit test for widening nullability in recursive query schema. - Ensured `RecursiveQuery` rebuilds correctly after child transforms using `try_new(...)`. - Updated deserialization of `RecursiveQuery` to leverage `try_new(...)`. - Enhanced `RecursiveQueryExec::try_new` to derive widened output schema using static and recursive schemas. - Introduced a helper function for generating recursive query output schema. - Updated tests for exec schema handling of recursive nullable outputs. - Added a SQL regression test to verify recursive term behavior and expected output. 
--- datafusion/expr/src/logical_plan/builder.rs | 31 +++++-- datafusion/expr/src/logical_plan/plan.rs | 92 ++++++++++++++++--- datafusion/expr/src/logical_plan/tree_node.rs | 13 +-- .../physical-plan/src/recursive_query.rs | 71 +++++++++++--- datafusion/proto/src/logical_plan/mod.rs | 12 +-- datafusion/sqllogictest/test_files/cte.slt | 13 +++ 6 files changed, 187 insertions(+), 45 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 017a123eb035b..745ffe1b6fb19 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -192,12 +192,14 @@ impl LogicalPlanBuilder { // Ensure that the recursive term has the same field types as the static term let coerced_recursive_term = coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?; - Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { - name, - static_term: self.plan, - recursive_term: Arc::new(coerced_recursive_term), - is_distinct, - }))) + Ok(Self::from(LogicalPlan::RecursiveQuery( + RecursiveQuery::try_new( + name, + self.plan, + Arc::new(coerced_recursive_term), + is_distinct, + )?, + ))) } /// Create a values list based relation, and the schema is inferred from data, consuming @@ -2367,6 +2369,23 @@ mod tests { Ok(()) } + #[test] + fn recursive_query_schema_widens_nullability_from_recursive_term() -> Result<()> { + let static_term = LogicalPlanBuilder::empty(true) + .project(vec![lit(0i32).alias("n")])?; + let recursive_term = LogicalPlanBuilder::empty(true) + .project(vec![lit(ScalarValue::Int32(None)).alias("recursive_n")])? + .build()?; + + let plan = static_term + .to_recursive_query("t".to_string(), recursive_term, false)? 
+ .build()?; + + assert_eq!(plan.schema().field(0).name(), "n"); + assert!(plan.schema().field(0).is_nullable()); + Ok(()) + } + #[test] fn plan_builder_union() -> Result<()> { let plan = diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index db8b82fe87a14..d4fa2f2c145ff 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -353,10 +353,7 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { output_schema, .. }) => output_schema, LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, - LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { - // we take the schema of the static term as the schema of the entire recursive query - static_term.schema() - } + LogicalPlan::RecursiveQuery(RecursiveQuery { schema, .. }) => schema, } } @@ -1080,12 +1077,12 @@ impl LogicalPlan { }) => { self.assert_no_expressions(expr)?; let (static_term, recursive_term) = self.only_two_inputs(inputs)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { - name: name.clone(), - static_term: Arc::new(static_term), - recursive_term: Arc::new(recursive_term), - is_distinct: *is_distinct, - })) + Ok(LogicalPlan::RecursiveQuery(RecursiveQuery::try_new( + name.clone(), + Arc::new(static_term), + Arc::new(recursive_term), + *is_distinct, + )?)) } LogicalPlan::Analyze(a) => { self.assert_no_expressions(expr)?; @@ -2246,7 +2243,7 @@ impl PartialOrd for EmptyRelation { /// intermediate table, then empty the intermediate table. 
/// /// [Postgres Docs]: https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct RecursiveQuery { /// Name of the query pub name: String, @@ -2255,11 +2252,84 @@ pub struct RecursiveQuery { /// The recursive term (evaluated on the contents of the working table until /// it returns an empty set) pub recursive_term: Arc, + /// Output schema, using static term field names and nullability widened + /// across both static and recursive terms. + pub schema: DFSchemaRef, /// Should the output of the recursive term be deduplicated (`UNION`) or /// not (`UNION ALL`). pub is_distinct: bool, } +impl RecursiveQuery { + /// Create a recursive query with an output schema using static term field names + /// and nullability widened across both static and recursive terms. + pub fn try_new( + name: String, + static_term: Arc, + recursive_term: Arc, + is_distinct: bool, + ) -> Result { + let schema = recursive_query_schema(static_term.schema(), recursive_term.schema())?; + Ok(Self { + name, + static_term, + recursive_term, + schema, + is_distinct, + }) + } +} + +fn recursive_query_schema( + static_schema: &DFSchema, + recursive_schema: &DFSchema, +) -> Result { + let fields = static_schema + .fields() + .iter() + .zip(recursive_schema.fields().iter()) + .enumerate() + .map(|(i, (static_field, recursive_field))| { + let (qualifier, _) = static_schema.qualified_field(i); + let mut field = Field::new( + static_field.name(), + static_field.data_type().clone(), + static_field.is_nullable() || recursive_field.is_nullable(), + ); + field.set_metadata(intersect_metadata_for_union([ + static_field.metadata(), + recursive_field.metadata(), + ])); + Ok((qualifier.cloned(), Arc::new(field))) + }) + .collect::>>()?; + + let metadata = intersect_metadata_for_union([ + static_schema.metadata(), + recursive_schema.metadata(), + ]); + 
Ok(Arc::new(DFSchema::new_with_metadata(fields, metadata)?)) +} + +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for RecursiveQuery { + fn partial_cmp(&self, other: &Self) -> Option { + ( + &self.name, + &self.static_term, + &self.recursive_term, + self.is_distinct, + ) + .partial_cmp(&( + &other.name, + &other.static_term, + &other.recursive_term, + other.is_distinct, + )) + .filter(|cmp| *cmp != Ordering::Equal || self == other) + } +} + /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index ef9382a57209a..0385637edd576 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -329,16 +329,13 @@ impl TreeNode for LogicalPlan { static_term, recursive_term, is_distinct, - }) => (static_term, recursive_term).map_elements(f)?.update_data( + .. 
+ }) => (static_term, recursive_term).map_elements(f)?.map_data( |(static_term, recursive_term)| { - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, - static_term, - recursive_term, - is_distinct, - }) + RecursiveQuery::try_new(name, static_term, recursive_term, is_distinct) + .map(LogicalPlan::RecursiveQuery) }, - ), + )?, LogicalPlan::Statement(stmt) => match stmt { Statement::Prepare(p) => p .input diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 5c79750332cca..062774f90a796 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -35,7 +35,7 @@ use crate::{ }; use arrow::array::{BooleanArray, BooleanBuilder}; use arrow::compute::filter_record_batch; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -93,8 +93,12 @@ impl RecursiveQueryExec { // Use the static term as the declared recursive CTE output schema. The // recursive term is planned independently, so align it at plan construction // time instead of patching batches in RecursiveQueryStream. 
- let output_schema = static_term.schema(); let recursive_term = assign_work_table(recursive_term, &work_table)?; + let output_schema = recursive_query_output_schema( + &static_term.schema(), + &recursive_term.schema(), + )?; + let static_term = align_plan_to_schema(static_term, &output_schema)?; let recursive_term = align_plan_to_schema(recursive_term, &output_schema)?; let cache = Self::compute_properties(Arc::clone(&output_schema)); Ok(RecursiveQueryExec { @@ -370,6 +374,47 @@ impl RecursiveQueryStream { } } +fn recursive_query_output_schema( + static_schema: &SchemaRef, + recursive_schema: &SchemaRef, +) -> Result { + if static_schema.fields().len() != recursive_schema.fields().len() { + return datafusion_common::plan_err!( + "RecursiveQueryExec static and recursive terms have different number of columns: {} != {}", + static_schema.fields().len(), + recursive_schema.fields().len() + ); + } + + let fields = static_schema + .fields() + .iter() + .zip(recursive_schema.fields().iter()) + .enumerate() + .map(|(i, (static_field, recursive_field))| { + if static_field.data_type() != recursive_field.data_type() { + return datafusion_common::plan_err!( + "RecursiveQueryExec column {i} has different types: static term has {} whereas recursive term has {}", + static_field.data_type(), + recursive_field.data_type() + ); + } + let mut field = Field::new( + static_field.name(), + static_field.data_type().clone(), + static_field.is_nullable() || recursive_field.is_nullable(), + ); + field.set_metadata(static_field.metadata().clone()); + Ok(Arc::new(field)) + }) + .collect::>>()?; + + Ok(Arc::new(Schema::new_with_metadata( + fields, + static_schema.metadata().clone(), + ))) +} + fn assign_work_table( plan: Arc, work_table: &Arc, @@ -496,7 +541,6 @@ fn new_groups_mask( #[cfg(test)] mod tests { use super::*; - use crate::common::SchemaAlignExec; use crate::empty::EmptyExec; use crate::projection::ProjectionExec; @@ -538,23 +582,22 @@ mod tests { } #[test] - fn 
recursive_query_exec_preserves_static_nullability_contract() -> Result<()> { + fn recursive_query_exec_widens_output_nullability_from_recursive_term() -> Result<()> { let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); let recursive_term = empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]); let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?; - let static_schema = static_term.schema(); - assert_eq!(exec.schema(), static_schema); - assert_eq!(exec.static_term().schema(), static_schema); - assert_eq!(exec.recursive_term().schema(), static_schema); - assert!(!exec.schema().field(0).is_nullable()); - let aligned = exec - .recursive_term() - .downcast_ref::() - .expect("nullable recursive term should be aligned with SchemaAlignExec"); - assert!(Arc::ptr_eq(aligned.input(), &recursive_term)); + let expected_schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int32, + true, + )])); + assert_eq!(exec.schema(), expected_schema); + assert_eq!(exec.static_term().schema(), expected_schema); + assert_eq!(exec.recursive_term().schema(), expected_schema); + assert!(exec.schema().field(0).is_nullable()); Ok(()) } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 7ae5cbeed3e53..b245ba235a964 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1067,12 +1067,12 @@ impl AsLogicalPlan for LogicalPlanNode { ))? 
.try_into_logical_plan(ctx, extension_codec)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { - name: recursive_query_node.name.clone(), - static_term: Arc::new(static_term), - recursive_term: Arc::new(recursive_term), - is_distinct: recursive_query_node.is_distinct, - })) + Ok(LogicalPlan::RecursiveQuery(RecursiveQuery::try_new( + recursive_query_node.name.clone(), + Arc::new(static_term), + Arc::new(recursive_term), + recursive_query_node.is_distinct, + )?)) } LogicalPlanType::CteWorkTableScan(cte_work_table_scan_node) => { let CteWorkTableScanNode { name, schema } = cte_work_table_scan_node; diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index bb5a18d53d82d..e16f96cdd44b7 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -1300,6 +1300,19 @@ DROP TABLE cte_schema_reread; statement ok DROP TABLE cte_schema_records; +# Recursive CTE nullability is union-like: anchor names are preserved, +# but nullable recursive output widens the CTE output schema. +query I +WITH RECURSIVE t AS ( + SELECT 0 AS n + UNION ALL + SELECT CAST(NULL AS INT) AS n FROM t WHERE n IS NOT NULL +) +SELECT * FROM t; +---- +0 +NULL + statement count 0 set datafusion.execution.enable_recursive_ctes = false; From 0b35795bb6519787de228285c472e473d40552b8 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 10 May 2026 22:31:33 +0800 Subject: [PATCH 07/24] fix: optimize recursive CTE handling to prevent SLT hang - Addressed issue with the work table being planned with anchor/static schema only. - Modified logic to ensure that recursive term is planned once with anchor schema, preventing non-null optimizations that lead to infinite NULL emissions. - Built initial recursive CTE schema and recreated work table if schema nullability widened. - Replanned the recursive term using the widened work table schema to avoid inefficiencies. 
--- datafusion/sql/src/cte.rs | 36 +++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 18766d7056355..a0b08e9ffdc43 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -159,7 +159,7 @@ impl SqlToRel<'_, S> { // this uses the named_relation we inserted above to resolve the // relation. This ensures that the recursive term uses the named relation logical plan // and thus the 'continuance' physical plan as its input and source - let recursive_plan = self.set_expr_to_plan(*right_expr, planner_context)?; + let recursive_plan = self.set_expr_to_plan(*right_expr.clone(), planner_context)?; // Check if the recursive term references the CTE itself, // if not, it is a non-recursive CTE @@ -176,11 +176,37 @@ impl SqlToRel<'_, S> { } // ---------- Step 4: Create the final plan ------------------ - // Step 4.1: Compile the final plan + // Step 4.1: Compile the final plan. Recursive CTE nullability is + // union-like, so the recursive term can widen the work table schema. + // Replan the recursive term with that widened schema so predicates such + // as `n IS NOT NULL` are not optimized using the anchor-only schema. let distinct = !Self::is_union_all(set_quantifier)?; - LogicalPlanBuilder::from(static_plan) - .to_recursive_query(name, recursive_plan, distinct)? - .build() + let initial_recursive_query = LogicalPlanBuilder::from(static_plan.clone()) + .to_recursive_query(name.clone(), recursive_plan.clone(), distinct)? + .build()?; + if initial_recursive_query.schema() != static_plan.schema() { + let work_table_source = self + .context_provider + .create_cte_work_table( + cte_name, + Arc::clone(initial_recursive_query.schema().inner()), + )?; + let work_table_plan = LogicalPlanBuilder::scan( + cte_name.to_string(), + Arc::clone(&work_table_source), + None, + )? 
+ .build()?; + planner_context.insert_cte(cte_name.to_string(), work_table_plan); + let recursive_plan = self.set_expr_to_plan(*right_expr, planner_context)?; + planner_context.remove_cte(cte_name); + LogicalPlanBuilder::from(static_plan) + .to_recursive_query(name, recursive_plan, distinct)? + .build() + } else { + planner_context.remove_cte(cte_name); + Ok(initial_recursive_query) + } } } From 0bdf95c7620032b9c227b02b7141a23829ce4ef2 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 10 May 2026 22:43:39 +0800 Subject: [PATCH 08/24] feat: refactor CTE handling and logical plan simplifications - Added private `cte_work_table_plan` in `cte.rs` - Removed duplicated work-table source/scan construction in `cte.rs` - Simplified `recursive_query_schema` in `plan.rs` - Removed unnecessary Result wrapping in field collection in `plan.rs` - Used `Field::with_metadata` in `plan.rs` - Updated stale comment and used `Field::with_metadata` in `recursive_query.rs` --- datafusion/expr/src/logical_plan/builder.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 13 ++--- datafusion/expr/src/logical_plan/tree_node.rs | 9 +++- .../physical-plan/src/recursive_query.rs | 24 +++++---- datafusion/sql/src/cte.rs | 52 ++++++++++--------- 5 files changed, 56 insertions(+), 46 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 745ffe1b6fb19..888b252d2cbcc 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -2371,8 +2371,8 @@ mod tests { #[test] fn recursive_query_schema_widens_nullability_from_recursive_term() -> Result<()> { - let static_term = LogicalPlanBuilder::empty(true) - .project(vec![lit(0i32).alias("n")])?; + let static_term = + LogicalPlanBuilder::empty(true).project(vec![lit(0i32).alias("n")])?; let recursive_term = LogicalPlanBuilder::empty(true) .project(vec![lit(ScalarValue::Int32(None)).alias("recursive_n")])? 
.build()?; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d4fa2f2c145ff..3421c1f007f67 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2269,7 +2269,8 @@ impl RecursiveQuery { recursive_term: Arc, is_distinct: bool, ) -> Result { - let schema = recursive_query_schema(static_term.schema(), recursive_term.schema())?; + let schema = + recursive_query_schema(static_term.schema(), recursive_term.schema())?; Ok(Self { name, static_term, @@ -2291,18 +2292,18 @@ fn recursive_query_schema( .enumerate() .map(|(i, (static_field, recursive_field))| { let (qualifier, _) = static_schema.qualified_field(i); - let mut field = Field::new( + let field = Field::new( static_field.name(), static_field.data_type().clone(), static_field.is_nullable() || recursive_field.is_nullable(), - ); - field.set_metadata(intersect_metadata_for_union([ + ) + .with_metadata(intersect_metadata_for_union([ static_field.metadata(), recursive_field.metadata(), ])); - Ok((qualifier.cloned(), Arc::new(field))) + (qualifier.cloned(), Arc::new(field)) }) - .collect::>>()?; + .collect::>(); let metadata = intersect_metadata_for_union([ static_schema.metadata(), diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 0385637edd576..0e54e4536439c 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -332,8 +332,13 @@ impl TreeNode for LogicalPlan { .. 
}) => (static_term, recursive_term).map_elements(f)?.map_data( |(static_term, recursive_term)| { - RecursiveQuery::try_new(name, static_term, recursive_term, is_distinct) - .map(LogicalPlan::RecursiveQuery) + RecursiveQuery::try_new( + name, + static_term, + recursive_term, + is_distinct, + ) + .map(LogicalPlan::RecursiveQuery) }, )?, LogicalPlan::Statement(stmt) => match stmt { diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 062774f90a796..3753c1ece681b 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -90,9 +90,9 @@ impl RecursiveQueryExec { ) -> Result { // Each recursive query needs its own work table let work_table = Arc::new(WorkTable::new(name.clone())); - // Use the static term as the declared recursive CTE output schema. The - // recursive term is planned independently, so align it at plan construction - // time instead of patching batches in RecursiveQueryStream. + // Use static term field names with nullability widened across static and + // recursive terms. Align both children at plan construction time instead + // of patching batches in RecursiveQueryStream. 
let recursive_term = assign_work_table(recursive_term, &work_table)?; let output_schema = recursive_query_output_schema( &static_term.schema(), @@ -399,13 +399,14 @@ fn recursive_query_output_schema( recursive_field.data_type() ); } - let mut field = Field::new( - static_field.name(), - static_field.data_type().clone(), - static_field.is_nullable() || recursive_field.is_nullable(), - ); - field.set_metadata(static_field.metadata().clone()); - Ok(Arc::new(field)) + Ok(Arc::new( + Field::new( + static_field.name(), + static_field.data_type().clone(), + static_field.is_nullable() || recursive_field.is_nullable(), + ) + .with_metadata(static_field.metadata().clone()), + )) }) .collect::>>()?; @@ -582,7 +583,8 @@ mod tests { } #[test] - fn recursive_query_exec_widens_output_nullability_from_recursive_term() -> Result<()> { + fn recursive_query_exec_widens_output_nullability_from_recursive_term() -> Result<()> + { let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); let recursive_term = empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]); diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index a0b08e9ffdc43..32c2ee2d96528 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use arrow::datatypes::SchemaRef; use datafusion_common::{ Result, not_impl_err, plan_err, tree_node::{TreeNode, TreeNodeRecursion}, @@ -132,19 +133,10 @@ impl SqlToRel<'_, S> { // bound to. 
// ---------- Step 2: Create a temporary relation ------------------ - // Step 2.1: Create a table source for the temporary relation - let work_table_source = self - .context_provider - .create_cte_work_table(cte_name, Arc::clone(static_plan.schema().inner()))?; - - // Step 2.2: Create a temporary relation logical plan that will be used + // Step 2.1: Create a temporary relation logical plan that will be used // as the input to the recursive term - let work_table_plan = LogicalPlanBuilder::scan( - cte_name.to_string(), - Arc::clone(&work_table_source), - None, - )? - .build()?; + let (work_table_source, work_table_plan) = + self.cte_work_table_plan(cte_name, Arc::clone(static_plan.schema().inner()))?; let name = cte_name.to_string(); @@ -159,7 +151,8 @@ impl SqlToRel<'_, S> { // this uses the named_relation we inserted above to resolve the // relation. This ensures that the recursive term uses the named relation logical plan // and thus the 'continuance' physical plan as its input and source - let recursive_plan = self.set_expr_to_plan(*right_expr.clone(), planner_context)?; + let recursive_plan = + self.set_expr_to_plan(*right_expr.clone(), planner_context)?; // Check if the recursive term references the CTE itself, // if not, it is a non-recursive CTE @@ -185,18 +178,10 @@ impl SqlToRel<'_, S> { .to_recursive_query(name.clone(), recursive_plan.clone(), distinct)? .build()?; if initial_recursive_query.schema() != static_plan.schema() { - let work_table_source = self - .context_provider - .create_cte_work_table( - cte_name, - Arc::clone(initial_recursive_query.schema().inner()), - )?; - let work_table_plan = LogicalPlanBuilder::scan( - cte_name.to_string(), - Arc::clone(&work_table_source), - None, - )? 
- .build()?; + let (_, work_table_plan) = self.cte_work_table_plan( + cte_name, + Arc::clone(initial_recursive_query.schema().inner()), + )?; planner_context.insert_cte(cte_name.to_string(), work_table_plan); let recursive_plan = self.set_expr_to_plan(*right_expr, planner_context)?; planner_context.remove_cte(cte_name); @@ -208,6 +193,23 @@ impl SqlToRel<'_, S> { Ok(initial_recursive_query) } } + + fn cte_work_table_plan( + &self, + cte_name: &str, + schema: SchemaRef, + ) -> Result<(Arc, LogicalPlan)> { + let work_table_source = self + .context_provider + .create_cte_work_table(cte_name, schema)?; + let work_table_plan = LogicalPlanBuilder::scan( + cte_name.to_string(), + Arc::clone(&work_table_source), + None, + )? + .build()?; + Ok((work_table_source, work_table_plan)) + } } fn has_work_table_reference( From 63f62a8295b6e729a803eda528e10091d0cc4e99 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 10 May 2026 23:08:40 +0800 Subject: [PATCH 09/24] feat: enhance recursive query validation and testing - Updated `RecursiveQuery::try_new` to validate column count and data types. - Added direct regression tests for logical plan. - Enhanced physical recursive schema to intersect field/schema metadata like logical schema. - Implemented metadata regression test in physical plan. - Improved `align_plan_to_schema` to align metadata via `SchemaAlignExec`. - Maintained behavior in `project_plan_to_schema` to reject metadata changes. - Added comment for projection-error fallback in common code. - Clarified comments regarding two-pass recursive planning in SQL component. 
--- datafusion/expr/src/logical_plan/plan.rs | 58 +++++++++++++++- datafusion/physical-plan/src/common.rs | 67 +++++++++++++------ .../physical-plan/src/recursive_query.rs | 61 ++++++++++++++++- datafusion/sql/src/cte.rs | 3 +- 4 files changed, 164 insertions(+), 25 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 3421c1f007f67..9960de040099a 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2285,12 +2285,27 @@ fn recursive_query_schema( static_schema: &DFSchema, recursive_schema: &DFSchema, ) -> Result { + if static_schema.fields().len() != recursive_schema.fields().len() { + return plan_err!( + "RecursiveQuery static and recursive terms have different number of columns: {} != {}", + static_schema.fields().len(), + recursive_schema.fields().len() + ); + } + let fields = static_schema .fields() .iter() .zip(recursive_schema.fields().iter()) .enumerate() .map(|(i, (static_field, recursive_field))| { + if static_field.data_type() != recursive_field.data_type() { + return plan_err!( + "RecursiveQuery column {i} has different types: static term has {} whereas recursive term has {}", + static_field.data_type(), + recursive_field.data_type() + ); + } let (qualifier, _) = static_schema.qualified_field(i); let field = Field::new( static_field.name(), @@ -2301,9 +2316,9 @@ fn recursive_query_schema( static_field.metadata(), recursive_field.metadata(), ])); - (qualifier.cloned(), Arc::new(field)) + Ok((qualifier.cloned(), Arc::new(field))) }) - .collect::>(); + .collect::>>()?; let metadata = intersect_metadata_for_union([ static_schema.metadata(), @@ -4942,6 +4957,45 @@ mod tests { ); } + fn empty_plan_with_fields(fields: Vec) -> Arc { + Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new( + DFSchema::from_unqualified_fields(fields.into(), HashMap::new()).unwrap(), + ), + })) + } + + #[test] + fn 
recursive_query_try_new_rejects_mismatched_column_count() { + let static_term = + empty_plan_with_fields(vec![Field::new("a", DataType::Int32, false)]); + let recursive_term = empty_plan_with_fields(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let err = + RecursiveQuery::try_new("t".to_string(), static_term, recursive_term, false) + .unwrap_err(); + + assert_snapshot!(err.strip_backtrace(), @"Error during planning: RecursiveQuery static and recursive terms have different number of columns: 1 != 2"); + } + + #[test] + fn recursive_query_try_new_rejects_mismatched_types() { + let static_term = + empty_plan_with_fields(vec![Field::new("a", DataType::Int32, false)]); + let recursive_term = + empty_plan_with_fields(vec![Field::new("a", DataType::Int64, false)]); + + let err = + RecursiveQuery::try_new("t".to_string(), static_term, recursive_term, false) + .unwrap_err(); + + assert_snapshot!(err.strip_backtrace(), @"Error during planning: RecursiveQuery column 0 has different types: static term has Int32 whereas recursive term has Int64"); + } + #[test] fn test_partial_eq_hash_and_partial_ord() { let empty_values = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index b3ab903f02228..2e5848b1097e2 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -107,9 +107,8 @@ fn build_file_list_recurse( /// is visible in the physical plan, while batch schema rebinding remains contained in the /// adapter as the implementation detail required to uphold the plan-level schema contract. /// -/// This helper can align field names and nullability to the declared schema. It rejects -/// differences that would change values or silently lose schema information, such as column -/// count, data type, field metadata, or schema metadata mismatches. 
+/// This helper can align field names, nullability, and metadata to the declared schema. It +/// rejects differences that would change values, such as column count or data type mismatches. /// /// When an adapter is required, it conservatively derives fresh equivalence properties from /// `expected_schema` and drops child hash partitioning because field names/nullability may have @@ -124,6 +123,10 @@ pub fn align_plan_to_schema( return Ok(input); } + // Projection is the preferred adapter, but not every valid schema-only + // alignment can be represented by ProjectionExec (for example nullability + // narrowing). Treat projection errors as path-selection only; if the + // fallback also fails, SchemaAlignExec returns the final diagnostic. if let Ok(projected) = project_plan_to_schema(Arc::clone(&input), expected_schema) { debug_assert_eq!(projected.schema().as_ref(), expected_schema.as_ref()); return Ok(projected); @@ -138,8 +141,8 @@ pub fn align_plan_to_schema( /// Project `input` to `expected_schema` when [`ProjectionExec`] can produce that exact schema. /// /// This is a narrower helper than [`align_plan_to_schema`]. It is useful when a positional -/// projection/alias is sufficient. It rejects requests where projection cannot advertise the -/// exact expected schema, such as nullability narrowing. +/// projection/alias is sufficient. It rejects requests where ProjectionExec cannot advertise the +/// exact expected schema, such as nullability narrowing or metadata changes. 
pub fn project_plan_to_schema( input: Arc, expected_schema: &SchemaRef, @@ -151,6 +154,36 @@ pub fn project_plan_to_schema( return Ok(input); } + if input_schema.metadata() != expected_schema.metadata() { + return plan_err!( + "Cannot project plan to expected schema: schema metadata differ" + ); + } + + if let Some((i, input_field, expected_field, mismatch)) = input_schema + .fields() + .iter() + .zip(expected_schema.fields().iter()) + .enumerate() + .find_map(|(i, (input_field, expected_field))| { + (input_field.metadata() != expected_field.metadata()).then_some(( + i, + input_field, + expected_field, + "metadata", + )) + }) + { + return plan_err!( + "Cannot project plan column {i} ('{}') to expected output field '{}': \ + field {mismatch} differs (input field: {:?}, expected field: {:?})", + input_field.name(), + expected_field.name(), + input_field, + expected_field + ); + } + if let Some((i, input_field, expected_field)) = input_schema .fields() .iter() @@ -215,12 +248,6 @@ fn validate_schema_alignment( ); } - if input_schema.metadata() != expected_schema.metadata() { - return plan_err!( - "Cannot {operation} plan to expected schema: schema metadata differ" - ); - } - if let Some((i, input_field, expected_field, mismatch)) = input_schema .fields() .iter() @@ -229,8 +256,6 @@ fn validate_schema_alignment( .find_map(|(i, (input_field, expected_field))| { if input_field.data_type() != expected_field.data_type() { Some((i, input_field, expected_field, "data type")) - } else if input_field.metadata() != expected_field.metadata() { - Some((i, input_field, expected_field, "metadata")) } else { None } @@ -804,19 +829,23 @@ mod tests { } #[test] - fn align_plan_to_schema_errors_on_field_metadata_mismatch() { + fn align_plan_to_schema_aligns_field_metadata() -> Result<()> { let (input, expected_schema) = field_metadata_mismatch(); - let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); - assert!(err.to_string().contains("field metadata differs")); + let 
result = align_plan_to_schema(input, &expected_schema)?; + + assert_eq!(result.schema(), expected_schema); + Ok(()) } #[test] - fn align_plan_to_schema_errors_on_schema_metadata_mismatch() { + fn align_plan_to_schema_aligns_schema_metadata() -> Result<()> { let (input, expected_schema) = schema_metadata_mismatch(); - let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); - assert!(err.to_string().contains("schema metadata differ")); + let result = align_plan_to_schema(input, &expected_schema)?; + + assert_eq!(result.schema(), expected_schema); + Ok(()) } #[test] diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 3753c1ece681b..2bb32c39cf3d8 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -42,6 +42,7 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, internal_datafusion_err, not_impl_err}; use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_expr::expr::intersect_metadata_for_union; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; @@ -405,14 +406,20 @@ fn recursive_query_output_schema( static_field.data_type().clone(), static_field.is_nullable() || recursive_field.is_nullable(), ) - .with_metadata(static_field.metadata().clone()), + .with_metadata(intersect_metadata_for_union([ + static_field.metadata(), + recursive_field.metadata(), + ])), )) }) .collect::>>()?; Ok(Arc::new(Schema::new_with_metadata( fields, - static_schema.metadata().clone(), + intersect_metadata_for_union([ + static_schema.metadata(), + recursive_schema.metadata(), + ]), ))) } @@ -546,9 +553,14 @@ mod tests { use crate::projection::ProjectionExec; use arrow::datatypes::{DataType, Field, Schema}; + use std::collections::HashMap; fn 
empty_exec(fields: Vec) -> Arc { - Arc::new(EmptyExec::new(Arc::new(Schema::new(fields)))) + empty_exec_with_schema(Arc::new(Schema::new(fields))) + } + + fn empty_exec_with_schema(schema: SchemaRef) -> Arc { + Arc::new(EmptyExec::new(schema)) } fn recursive_exec( @@ -602,4 +614,47 @@ mod tests { assert!(exec.schema().field(0).is_nullable()); Ok(()) } + + #[test] + fn recursive_query_exec_intersects_output_metadata() -> Result<()> { + let static_field = + Field::new("value", DataType::Int32, false).with_metadata(HashMap::from([ + ("shared".to_string(), "same".to_string()), + ("static".to_string(), "only".to_string()), + ])); + let recursive_field = + Field::new("value", DataType::Int32, false).with_metadata(HashMap::from([ + ("shared".to_string(), "same".to_string()), + ("static".to_string(), "different".to_string()), + ])); + let static_schema = Arc::new(Schema::new_with_metadata( + vec![static_field], + HashMap::from([ + ("shared".to_string(), "same".to_string()), + ("static".to_string(), "only".to_string()), + ]), + )); + let recursive_schema = Arc::new(Schema::new_with_metadata( + vec![recursive_field], + HashMap::from([ + ("shared".to_string(), "same".to_string()), + ("static".to_string(), "different".to_string()), + ]), + )); + + let exec = recursive_exec( + empty_exec_with_schema(static_schema), + empty_exec_with_schema(recursive_schema), + )?; + + assert_eq!( + exec.schema().field(0).metadata(), + &HashMap::from([("shared".to_string(), "same".to_string())]) + ); + assert_eq!( + exec.schema().metadata(), + &HashMap::from([("shared".to_string(), "same".to_string())]) + ); + Ok(()) + } } diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 32c2ee2d96528..74c92258889cc 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -169,7 +169,8 @@ impl SqlToRel<'_, S> { } // ---------- Step 4: Create the final plan ------------------ - // Step 4.1: Compile the final plan. 
Recursive CTE nullability is + // Step 4.1: Compile the final plan. The first plan only discovers the + // fixed recursive CTE output schema. Recursive CTE nullability is // union-like, so the recursive term can widen the work table schema. // Replan the recursive term with that widened schema so predicates such // as `n IS NOT NULL` are not optimized using the anchor-only schema. From 8b4786f367a5664f9e896fb04ecf745fa72809b6 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 11 May 2026 08:49:38 +0800 Subject: [PATCH 10/24] feat: enhance recursive CTE handling in DataFusion - Updated RecursiveQueryExec to accept declared logical recursive CTE schema. - Removed physical recursive schema recomputation, using logical schema as source of truth. - Aligned children to declared schema. - Introduced private recursive-CTE-local schema rebind exec for metadata/name/schema-only fixes. - Eliminated broad global align_plan_to_schema and SchemaAlignExec, retaining narrower project_plan_to_schema. --- datafusion/core/src/physical_planner.rs | 6 +- datafusion/physical-plan/src/common.rs | 330 ++---------------- .../physical-plan/src/recursive_query.rs | 227 +++++++++--- 3 files changed, 204 insertions(+), 359 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 3b2c7a78e898e..7304c197c6602 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1777,13 +1777,17 @@ impl DefaultPhysicalPlanner { } } LogicalPlan::RecursiveQuery(RecursiveQuery { - name, is_distinct, .. + name, + schema, + is_distinct, + .. }) => { let [static_term, recursive_term] = children.two()?; Arc::new(RecursiveQueryExec::try_new( name.clone(), static_term, recursive_term, + Arc::clone(schema.inner()), *is_distinct, )?) 
} diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 2e5848b1097e2..5c29966a1f719 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -24,21 +24,15 @@ use std::sync::Arc; use super::SendableRecordBatchStream; use crate::expressions::{CastExpr, Column}; use crate::projection::{ProjectionExec, ProjectionExpr}; -use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; -use crate::{ - ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - PlanProperties, Statistics, -}; +use crate::stream::RecordBatchReceiverStream; +use crate::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::array::Array; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{Result, plan_err}; -use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::MemoryReservation; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; use futures::{StreamExt, TryStreamExt}; @@ -96,60 +90,32 @@ fn build_file_list_recurse( Ok(()) } -/// Align `input`'s physical plan schema with `expected_schema`. +/// Project `input` to `expected_schema` when [`ProjectionExec`] can produce that exact schema. /// /// This helper is intended for operators that combine independently planned children but /// expose a single declared output schema. It returns `input` unchanged when schemas already -/// match exactly. Otherwise, it validates positional compatibility and uses a plan-time -/// adapter whose advertised and emitted schema is exactly `expected_schema`. -/// -/// Prefer this helper over rebinding batches inside a parent operator's stream. 
The alignment -/// is visible in the physical plan, while batch schema rebinding remains contained in the -/// adapter as the implementation detail required to uphold the plan-level schema contract. -/// -/// This helper can align field names, nullability, and metadata to the declared schema. It -/// rejects differences that would change values, such as column count or data type mismatches. +/// match exactly. Otherwise, it validates that projection can safely produce the expected +/// schema, then wraps `input` in a [`ProjectionExec`] that keeps columns in their existing +/// positional order and aliases them to `expected_schema`'s field names. /// -/// When an adapter is required, it conservatively derives fresh equivalence properties from -/// `expected_schema` and drops child hash partitioning because field names/nullability may have -/// changed while the underlying partitioning expressions still refer to the child schema. -pub fn align_plan_to_schema( +/// [`ProjectionExec`] can rename fields. When the expected field is nullable and the input +/// field is not, this helper also widens nullability with a same-type [`CastExpr`]. It rejects +/// differences that projection cannot safely normalize exactly, such as data type, metadata, +/// schema metadata, and nullability narrowing. +pub fn project_plan_to_schema( input: Arc, expected_schema: &SchemaRef, ) -> Result> { let input_schema = input.schema(); - if input_schema.as_ref() == expected_schema.as_ref() { - return Ok(input); - } - - // Projection is the preferred adapter, but not every valid schema-only - // alignment can be represented by ProjectionExec (for example nullability - // narrowing). Treat projection errors as path-selection only; if the - // fallback also fails, SchemaAlignExec returns the final diagnostic. 
- if let Ok(projected) = project_plan_to_schema(Arc::clone(&input), expected_schema) { - debug_assert_eq!(projected.schema().as_ref(), expected_schema.as_ref()); - return Ok(projected); + if input_schema.fields().len() != expected_schema.fields().len() { + return plan_err!( + "Cannot project plan to expected schema: expected {} column(s), got {}", + expected_schema.fields().len(), + input_schema.fields().len() + ); } - Ok(Arc::new(SchemaAlignExec::try_new( - input, - Arc::clone(expected_schema), - )?)) -} - -/// Project `input` to `expected_schema` when [`ProjectionExec`] can produce that exact schema. -/// -/// This is a narrower helper than [`align_plan_to_schema`]. It is useful when a positional -/// projection/alias is sufficient. It rejects requests where ProjectionExec cannot advertise the -/// exact expected schema, such as nullability narrowing or metadata changes. -pub fn project_plan_to_schema( - input: Arc, - expected_schema: &SchemaRef, -) -> Result> { - let input_schema = input.schema(); - validate_schema_alignment(&input_schema, expected_schema, "project")?; - if input_schema.as_ref() == expected_schema.as_ref() { return Ok(input); } @@ -166,12 +132,13 @@ pub fn project_plan_to_schema( .zip(expected_schema.fields().iter()) .enumerate() .find_map(|(i, (input_field, expected_field))| { - (input_field.metadata() != expected_field.metadata()).then_some(( - i, - input_field, - expected_field, - "metadata", - )) + if input_field.data_type() != expected_field.data_type() { + Some((i, input_field, expected_field, "data type")) + } else if input_field.metadata() != expected_field.metadata() { + Some((i, input_field, expected_field, "metadata")) + } else { + None + } }) { return plan_err!( @@ -235,173 +202,6 @@ pub fn project_plan_to_schema( Ok(Arc::new(projection)) } -fn validate_schema_alignment( - input_schema: &SchemaRef, - expected_schema: &SchemaRef, - operation: &str, -) -> Result<()> { - if input_schema.fields().len() != expected_schema.fields().len() { 
- return plan_err!( - "Cannot {operation} plan to expected schema: expected {} column(s), got {}", - expected_schema.fields().len(), - input_schema.fields().len() - ); - } - - if let Some((i, input_field, expected_field, mismatch)) = input_schema - .fields() - .iter() - .zip(expected_schema.fields().iter()) - .enumerate() - .find_map(|(i, (input_field, expected_field))| { - if input_field.data_type() != expected_field.data_type() { - Some((i, input_field, expected_field, "data type")) - } else { - None - } - }) - { - return plan_err!( - "Cannot {operation} plan column {i} ('{}') to expected output field '{}': \ - field {mismatch} differs (input field: {:?}, expected field: {:?})", - input_field.name(), - expected_field.name(), - input_field, - expected_field - ); - } - - Ok(()) -} - -/// Plan-time schema adapter for positional schema alignment. -/// -/// [`ProjectionExec`] cannot express every schema-only alignment. In particular, a column -/// expression remains nullable when its input field is nullable, so projection cannot advertise -/// a non-null expected field. This adapter is for cases where the operator-level contract has -/// already established that columns are positionally compatible and the child plan must expose -/// the declared schema exactly. -#[derive(Debug, Clone)] -pub struct SchemaAlignExec { - input: Arc, - schema: SchemaRef, - cache: Arc, -} - -impl SchemaAlignExec { - /// Create a new schema alignment adapter. 
- pub fn try_new(input: Arc, schema: SchemaRef) -> Result { - validate_schema_alignment(&input.schema(), &schema, "align")?; - - let input_properties = input.properties(); - let partitioning = match &input_properties.partitioning { - Partitioning::RoundRobinBatch(partitions) => { - Partitioning::RoundRobinBatch(*partitions) - } - partitioning => { - Partitioning::UnknownPartitioning(partitioning.partition_count()) - } - }; - let properties = PlanProperties::new( - EquivalenceProperties::new(Arc::clone(&schema)), - partitioning, - input_properties.emission_type, - input_properties.boundedness, - ) - .with_evaluation_type(input_properties.evaluation_type) - .with_scheduling_type(input_properties.scheduling_type); - - Ok(Self { - input, - schema, - cache: Arc::new(properties), - }) - } - - /// Input plan being aligned. - pub fn input(&self) -> &Arc { - &self.input - } -} - -impl DisplayAs for SchemaAlignExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "SchemaAlignExec") - } - DisplayFormatType::TreeRender => Ok(()), - } - } -} - -impl ExecutionPlan for SchemaAlignExec { - fn name(&self) -> &'static str { - "SchemaAlignExec" - } - - fn properties(&self) -> &Arc { - &self.cache - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn apply_expressions( - &self, - _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, - ) -> Result { - Ok(TreeNodeRecursion::Continue) - } - - fn maintains_input_order(&self) -> Vec { - vec![true] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - let [input] = children.try_into().map_err(|children: Vec<_>| { - datafusion_common::DataFusionError::Internal(format!( - "SchemaAlignExec expected 1 child, got {}", - children.len() - )) - })?; - Ok(Arc::new(Self::try_new(input, Arc::clone(&self.schema))?)) - } - - fn execute( - &self, - partition: usize, - 
context: Arc, - ) -> Result { - let schema = Arc::clone(&self.schema); - let stream = self.input.execute(partition, context)?.map({ - let schema = Arc::clone(&schema); - move |batch| { - let batch = batch?; - if batch.schema().as_ref() == schema.as_ref() { - Ok(batch) - } else { - RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec()) - .map_err(Into::into) - } - } - }); - Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) - } - - fn partition_statistics(&self, partition: Option) -> Result> { - self.input.partition_statistics(partition) - } -} - /// If running in a tokio context spawns the execution of `stream` to a separate task /// allowing it to execute in parallel with an intermediate buffer of size `buffer` pub fn spawn_buffered( @@ -766,88 +566,6 @@ mod tests { assert!(err.to_string().contains("field nullability differs")); } - #[test] - fn align_plan_to_schema_returns_input_when_schema_matches() -> Result<()> { - let schema = single_field_schema("value", DataType::Int32, false); - let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); - - let result = align_plan_to_schema(Arc::clone(&input), &schema)?; - - assert!(Arc::ptr_eq(&input, &result)); - Ok(()) - } - - #[test] - fn align_plan_to_schema_uses_projection_for_rename_only() -> Result<()> { - let input = single_i32_exec("recursive_a", false); - let expected_schema = single_field_schema("a", DataType::Int32, false); - - let result = align_plan_to_schema(Arc::clone(&input), &expected_schema)?; - - let projection = result - .downcast_ref::() - .expect("rename-only alignment should use ProjectionExec"); - assert!(Arc::ptr_eq(projection.input(), &input)); - assert_eq!(projection.schema(), expected_schema); - Ok(()) - } - - #[test] - fn align_plan_to_schema_uses_adapter_for_nullability_narrowing() -> Result<()> { - let input = single_i32_exec("a", true); - let expected_schema = single_field_schema("renamed", DataType::Int32, false); - - let result = 
align_plan_to_schema(Arc::clone(&input), &expected_schema)?; - - let aligned = result - .downcast_ref::() - .expect("nullability narrowing should use SchemaAlignExec"); - assert!(Arc::ptr_eq(aligned.input(), &input)); - assert_eq!(aligned.schema(), expected_schema); - Ok(()) - } - - #[test] - fn align_plan_to_schema_errors_on_column_count_mismatch() { - let input = single_i32_exec("a", false); - let expected_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ])); - - let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); - assert!(err.to_string().contains("expected 2 column")); - } - - #[test] - fn align_plan_to_schema_errors_on_type_mismatch() { - let input = single_i32_exec("a", false); - let expected_schema = single_field_schema("a", DataType::Float32, false); - - let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); - assert!(err.to_string().contains("field data type differs")); - } - - #[test] - fn align_plan_to_schema_aligns_field_metadata() -> Result<()> { - let (input, expected_schema) = field_metadata_mismatch(); - - let result = align_plan_to_schema(input, &expected_schema)?; - - assert_eq!(result.schema(), expected_schema); - Ok(()) - } - - #[test] - fn align_plan_to_schema_aligns_schema_metadata() -> Result<()> { - let (input, expected_schema) = schema_metadata_mismatch(); - - let result = align_plan_to_schema(input, &expected_schema)?; - - assert_eq!(result.schema(), expected_schema); - Ok(()) - } - #[test] fn project_plan_to_schema_errors_on_field_metadata_mismatch() { let (input, expected_schema) = field_metadata_mismatch(); diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 2bb32c39cf3d8..14a306f4b36aa 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -24,25 +24,25 @@ use std::task::{Context, Poll}; use 
super::work_table::{ReservedBatches, WorkTable}; use crate::aggregates::group_values::{GroupValues, new_group_values}; use crate::aggregates::order::GroupOrdering; -use crate::common::align_plan_to_schema; +use crate::common::project_plan_to_schema; use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states}; use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, }; +use crate::stream::RecordBatchStreamAdapter; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; use arrow::array::{BooleanArray, BooleanBuilder}; use arrow::compute::filter_record_batch; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, internal_datafusion_err, not_impl_err}; use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; -use datafusion_expr::expr::intersect_metadata_for_union; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; @@ -87,20 +87,18 @@ impl RecursiveQueryExec { name: String, static_term: Arc, recursive_term: Arc, + output_schema: SchemaRef, is_distinct: bool, ) -> Result { // Each recursive query needs its own work table let work_table = Arc::new(WorkTable::new(name.clone())); - // Use static term field names with nullability widened across static and - // recursive terms. Align both children at plan construction time instead - // of patching batches in RecursiveQueryStream. + // The logical recursive CTE schema is authoritative. Align both + // children at plan construction time instead of patching batches in + // RecursiveQueryStream. 
let recursive_term = assign_work_table(recursive_term, &work_table)?; - let output_schema = recursive_query_output_schema( - &static_term.schema(), - &recursive_term.schema(), - )?; - let static_term = align_plan_to_schema(static_term, &output_schema)?; - let recursive_term = align_plan_to_schema(recursive_term, &output_schema)?; + let static_term = align_recursive_plan_to_schema(static_term, &output_schema)?; + let recursive_term = + align_recursive_plan_to_schema(recursive_term, &output_schema)?; let cache = Self::compute_properties(Arc::clone(&output_schema)); Ok(RecursiveQueryExec { name, @@ -191,6 +189,7 @@ impl ExecutionPlan for RecursiveQueryExec { self.name.clone(), Arc::clone(&children[0]), Arc::clone(&children[1]), + self.schema(), self.is_distinct, ) .map(|e| Arc::new(e) as _) @@ -375,52 +374,160 @@ impl RecursiveQueryStream { } } -fn recursive_query_output_schema( - static_schema: &SchemaRef, - recursive_schema: &SchemaRef, -) -> Result { - if static_schema.fields().len() != recursive_schema.fields().len() { +fn align_recursive_plan_to_schema( + input: Arc, + output_schema: &SchemaRef, +) -> Result> { + match project_plan_to_schema(Arc::clone(&input), output_schema) { + Ok(projected) => Ok(projected), + Err(_) => Ok(Arc::new(RecursiveSchemaRebindExec::try_new( + input, + Arc::clone(output_schema), + )?)), + } +} + +#[derive(Debug, Clone)] +struct RecursiveSchemaRebindExec { + input: Arc, + schema: SchemaRef, + cache: Arc, +} + +impl RecursiveSchemaRebindExec { + fn try_new(input: Arc, schema: SchemaRef) -> Result { + validate_recursive_schema_rebind(&input.schema(), &schema)?; + let input_properties = input.properties(); + let properties = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + Partitioning::UnknownPartitioning( + input_properties.partitioning.partition_count(), + ), + input_properties.emission_type, + input_properties.boundedness, + ) + .with_evaluation_type(input_properties.evaluation_type) + 
.with_scheduling_type(input_properties.scheduling_type); + + Ok(Self { + input, + schema, + cache: Arc::new(properties), + }) + } +} + +impl DisplayAs for RecursiveSchemaRebindExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "RecursiveSchemaRebindExec") + } + DisplayFormatType::TreeRender => Ok(()), + } + } +} + +impl ExecutionPlan for RecursiveSchemaRebindExec { + fn name(&self) -> &'static str { + "RecursiveSchemaRebindExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let [input] = children.try_into().map_err(|children: Vec<_>| { + internal_datafusion_err!( + "RecursiveSchemaRebindExec expected 1 child, got {}", + children.len() + ) + })?; + Ok(Arc::new(Self::try_new(input, Arc::clone(&self.schema))?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let schema = Arc::clone(&self.schema); + let stream = self.input.execute(partition, context)?.map({ + let schema = Arc::clone(&schema); + move |batch| { + let batch = batch?; + if batch.schema().as_ref() == schema.as_ref() { + Ok(batch) + } else { + RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec()) + .map_err(Into::into) + } + } + }); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) + } +} + +fn validate_recursive_schema_rebind( + input_schema: &SchemaRef, + output_schema: &SchemaRef, +) -> Result<()> { + if input_schema.fields().len() != output_schema.fields().len() { return datafusion_common::plan_err!( - "RecursiveQueryExec static and recursive terms have 
different number of columns: {} != {}", - static_schema.fields().len(), - recursive_schema.fields().len() + "RecursiveQueryExec input and output schemas have different number of columns: {} != {}", + input_schema.fields().len(), + output_schema.fields().len() ); } - let fields = static_schema + if let Some((i, input_field, output_field, mismatch)) = input_schema .fields() .iter() - .zip(recursive_schema.fields().iter()) + .zip(output_schema.fields().iter()) .enumerate() - .map(|(i, (static_field, recursive_field))| { - if static_field.data_type() != recursive_field.data_type() { - return datafusion_common::plan_err!( - "RecursiveQueryExec column {i} has different types: static term has {} whereas recursive term has {}", - static_field.data_type(), - recursive_field.data_type() - ); + .find_map(|(i, (input_field, output_field))| { + if input_field.data_type() != output_field.data_type() { + Some((i, input_field, output_field, "type")) + } else if input_field.is_nullable() && !output_field.is_nullable() { + Some((i, input_field, output_field, "nullability")) + } else { + None } - Ok(Arc::new( - Field::new( - static_field.name(), - static_field.data_type().clone(), - static_field.is_nullable() || recursive_field.is_nullable(), - ) - .with_metadata(intersect_metadata_for_union([ - static_field.metadata(), - recursive_field.metadata(), - ])), - )) }) - .collect::>>()?; - - Ok(Arc::new(Schema::new_with_metadata( - fields, - intersect_metadata_for_union([ - static_schema.metadata(), - recursive_schema.metadata(), - ]), - ))) + { + return datafusion_common::plan_err!( + "Cannot align recursive query column {i} ('{}') to output field '{}': field {mismatch} differs (input field: {:?}, output field: {:?})", + input_field.name(), + output_field.name(), + input_field, + output_field + ); + } + + Ok(()) } fn assign_work_table( @@ -566,11 +673,13 @@ mod tests { fn recursive_exec( static_term: Arc, recursive_term: Arc, + output_schema: SchemaRef, ) -> Result { 
RecursiveQueryExec::try_new( "numbers".to_string(), static_term, recursive_term, + output_schema, false, ) } @@ -581,7 +690,11 @@ mod tests { let recursive_term = empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, false)]); - let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?; + let exec = recursive_exec( + Arc::clone(&static_term), + Arc::clone(&recursive_term), + static_term.schema(), + )?; assert_eq!(exec.schema(), static_term.schema()); let projection = exec @@ -601,13 +714,16 @@ mod tests { let recursive_term = empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]); - let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?; - let expected_schema = Arc::new(Schema::new(vec![Field::new( "value", DataType::Int32, true, )])); + let exec = recursive_exec( + Arc::clone(&static_term), + Arc::clone(&recursive_term), + Arc::clone(&expected_schema), + )?; assert_eq!(exec.schema(), expected_schema); assert_eq!(exec.static_term().schema(), expected_schema); assert_eq!(exec.recursive_term().schema(), expected_schema); @@ -642,9 +758,16 @@ mod tests { ]), )); + let expected_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("value", DataType::Int32, false).with_metadata( + HashMap::from([("shared".to_string(), "same".to_string())]), + )], + HashMap::from([("shared".to_string(), "same".to_string())]), + )); let exec = recursive_exec( empty_exec_with_schema(static_schema), empty_exec_with_schema(recursive_schema), + expected_schema, )?; assert_eq!( From 07112de4d8f149c367625abba33d4d7aa8f314c3 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 11 May 2026 09:26:44 +0800 Subject: [PATCH 11/24] feat(datafusion): enhance recursive query handling and error management - Renamed helper function from `align_recursive_plan_to_schema` to `align_recursive_child_to_logical_schema`. 
- Updated fallback mechanism to preserve `project_plan_to_schema` errors when local rebind cannot handle cases safely. - `RecursiveSchemaRebindExec` now rejects: - Schema metadata mismatches - Field metadata mismatches - Column count mismatches - Type mismatches - Maintained support for nullability-only schema rebind. - Updated tests to include: - Nullability rebind test - Field metadata rejection test - Schema metadata rejection test --- .../physical-plan/src/recursive_query.rs | 124 +++++++++++------- 1 file changed, 74 insertions(+), 50 deletions(-) diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 14a306f4b36aa..04c8e2811a587 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -96,9 +96,10 @@ impl RecursiveQueryExec { // children at plan construction time instead of patching batches in // RecursiveQueryStream. let recursive_term = assign_work_table(recursive_term, &work_table)?; - let static_term = align_recursive_plan_to_schema(static_term, &output_schema)?; + let static_term = + align_recursive_child_to_logical_schema(static_term, &output_schema)?; let recursive_term = - align_recursive_plan_to_schema(recursive_term, &output_schema)?; + align_recursive_child_to_logical_schema(recursive_term, &output_schema)?; let cache = Self::compute_properties(Arc::clone(&output_schema)); Ok(RecursiveQueryExec { name, @@ -374,16 +375,18 @@ impl RecursiveQueryStream { } } -fn align_recursive_plan_to_schema( +fn align_recursive_child_to_logical_schema( input: Arc, output_schema: &SchemaRef, ) -> Result> { match project_plan_to_schema(Arc::clone(&input), output_schema) { Ok(projected) => Ok(projected), - Err(_) => Ok(Arc::new(RecursiveSchemaRebindExec::try_new( - input, - Arc::clone(output_schema), - )?)), + Err(projection_error) => { + match RecursiveSchemaRebindExec::try_new(input, Arc::clone(output_schema)) { + Ok(exec) => Ok(Arc::new(exec)), 
+ Err(_) => Err(projection_error), + } + } } } @@ -503,6 +506,12 @@ fn validate_recursive_schema_rebind( ); } + if input_schema.metadata() != output_schema.metadata() { + return datafusion_common::plan_err!( + "Cannot align recursive query input to output schema: schema metadata differ" + ); + } + if let Some((i, input_field, output_field, mismatch)) = input_schema .fields() .iter() @@ -511,8 +520,8 @@ fn validate_recursive_schema_rebind( .find_map(|(i, (input_field, output_field))| { if input_field.data_type() != output_field.data_type() { Some((i, input_field, output_field, "type")) - } else if input_field.is_nullable() && !output_field.is_nullable() { - Some((i, input_field, output_field, "nullability")) + } else if input_field.metadata() != output_field.metadata() { + Some((i, input_field, output_field, "metadata")) } else { None } @@ -732,52 +741,67 @@ mod tests { } #[test] - fn recursive_query_exec_intersects_output_metadata() -> Result<()> { - let static_field = - Field::new("value", DataType::Int32, false).with_metadata(HashMap::from([ - ("shared".to_string(), "same".to_string()), - ("static".to_string(), "only".to_string()), - ])); - let recursive_field = - Field::new("value", DataType::Int32, false).with_metadata(HashMap::from([ - ("shared".to_string(), "same".to_string()), - ("static".to_string(), "different".to_string()), - ])); - let static_schema = Arc::new(Schema::new_with_metadata( - vec![static_field], - HashMap::from([ - ("shared".to_string(), "same".to_string()), - ("static".to_string(), "only".to_string()), - ]), - )); - let recursive_schema = Arc::new(Schema::new_with_metadata( - vec![recursive_field], - HashMap::from([ - ("shared".to_string(), "same".to_string()), - ("static".to_string(), "different".to_string()), - ]), - )); + fn recursive_query_exec_uses_rebind_for_nullability_narrowing() -> Result<()> { + let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); + let recursive_term = 
empty_exec(vec![Field::new("value", DataType::Int32, true)]); + let output_schema = static_term.schema(); - let expected_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("value", DataType::Int32, false).with_metadata( - HashMap::from([("shared".to_string(), "same".to_string())]), - )], - HashMap::from([("shared".to_string(), "same".to_string())]), - )); let exec = recursive_exec( - empty_exec_with_schema(static_schema), - empty_exec_with_schema(recursive_schema), - expected_schema, + Arc::clone(&static_term), + Arc::clone(&recursive_term), + Arc::clone(&output_schema), )?; - assert_eq!( - exec.schema().field(0).metadata(), - &HashMap::from([("shared".to_string(), "same".to_string())]) - ); - assert_eq!( - exec.schema().metadata(), - &HashMap::from([("shared".to_string(), "same".to_string())]) + assert_eq!(exec.schema(), output_schema); + assert_eq!(exec.recursive_term().schema(), output_schema); + assert!( + exec.recursive_term() + .downcast_ref::() + .is_some() ); Ok(()) } + + #[test] + fn recursive_query_exec_rejects_field_metadata_mismatch() { + let input_metadata = HashMap::from([("source".to_string(), "input".to_string())]); + let output_metadata = + HashMap::from([("source".to_string(), "output".to_string())]); + let static_schema = Arc::new(Schema::new(vec![ + Field::new("value", DataType::Int32, false).with_metadata(input_metadata), + ])); + let output_schema = Arc::new(Schema::new(vec![ + Field::new("value", DataType::Int32, false).with_metadata(output_metadata), + ])); + + let err = recursive_exec( + empty_exec_with_schema(static_schema), + empty_exec(vec![Field::new("value", DataType::Int32, false)]), + output_schema, + ) + .unwrap_err(); + + assert!(err.to_string().contains("field metadata differs")); + } + + #[test] + fn recursive_query_exec_rejects_schema_metadata_mismatch() { + let static_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("value", DataType::Int32, false)], + HashMap::from([("source".to_string(), 
"input".to_string())]), + )); + let output_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("value", DataType::Int32, false)], + HashMap::from([("source".to_string(), "output".to_string())]), + )); + + let err = recursive_exec( + empty_exec_with_schema(static_schema), + empty_exec(vec![Field::new("value", DataType::Int32, false)]), + output_schema, + ) + .unwrap_err(); + + assert!(err.to_string().contains("schema metadata differ")); + } } From 4c17d7a83ac57c3e3f81c908444fb1dadf157fe9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 11 May 2026 10:05:41 +0800 Subject: [PATCH 12/24] feat: update recursive_query and physical_planner to use references - Updated function signature in recursive_query to take a reference - Updated internal call site in with_new_children to accommodate the change - Modified test helper and all affected test call sites in recursive_query - Updated planner call site in physical_planner to align with new function signature --- datafusion/core/src/physical_planner.rs | 2 +- .../physical-plan/src/recursive_query.rs | 22 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 7304c197c6602..66f40547ebd4e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1787,7 +1787,7 @@ impl DefaultPhysicalPlanner { name.clone(), static_term, recursive_term, - Arc::clone(schema.inner()), + schema.inner(), *is_distinct, )?) 
} diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 04c8e2811a587..5d00e7b350e71 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -87,7 +87,7 @@ impl RecursiveQueryExec { name: String, static_term: Arc, recursive_term: Arc, - output_schema: SchemaRef, + output_schema: &SchemaRef, is_distinct: bool, ) -> Result { // Each recursive query needs its own work table @@ -97,10 +97,10 @@ impl RecursiveQueryExec { // RecursiveQueryStream. let recursive_term = assign_work_table(recursive_term, &work_table)?; let static_term = - align_recursive_child_to_logical_schema(static_term, &output_schema)?; + align_recursive_child_to_logical_schema(static_term, output_schema)?; let recursive_term = - align_recursive_child_to_logical_schema(recursive_term, &output_schema)?; - let cache = Self::compute_properties(Arc::clone(&output_schema)); + align_recursive_child_to_logical_schema(recursive_term, output_schema)?; + let cache = Self::compute_properties(Arc::clone(output_schema)); Ok(RecursiveQueryExec { name, static_term, @@ -190,7 +190,7 @@ impl ExecutionPlan for RecursiveQueryExec { self.name.clone(), Arc::clone(&children[0]), Arc::clone(&children[1]), - self.schema(), + &self.schema(), self.is_distinct, ) .map(|e| Arc::new(e) as _) @@ -682,7 +682,7 @@ mod tests { fn recursive_exec( static_term: Arc, recursive_term: Arc, - output_schema: SchemaRef, + output_schema: &SchemaRef, ) -> Result { RecursiveQueryExec::try_new( "numbers".to_string(), @@ -702,7 +702,7 @@ mod tests { let exec = recursive_exec( Arc::clone(&static_term), Arc::clone(&recursive_term), - static_term.schema(), + &static_term.schema(), )?; assert_eq!(exec.schema(), static_term.schema()); @@ -731,7 +731,7 @@ mod tests { let exec = recursive_exec( Arc::clone(&static_term), Arc::clone(&recursive_term), - Arc::clone(&expected_schema), + &expected_schema, )?; assert_eq!(exec.schema(), 
expected_schema); assert_eq!(exec.static_term().schema(), expected_schema); @@ -749,7 +749,7 @@ mod tests { let exec = recursive_exec( Arc::clone(&static_term), Arc::clone(&recursive_term), - Arc::clone(&output_schema), + &output_schema, )?; assert_eq!(exec.schema(), output_schema); @@ -777,7 +777,7 @@ mod tests { let err = recursive_exec( empty_exec_with_schema(static_schema), empty_exec(vec![Field::new("value", DataType::Int32, false)]), - output_schema, + &output_schema, ) .unwrap_err(); @@ -798,7 +798,7 @@ mod tests { let err = recursive_exec( empty_exec_with_schema(static_schema), empty_exec(vec![Field::new("value", DataType::Int32, false)]), - output_schema, + &output_schema, ) .unwrap_err(); From 365b9cabf61e549e85b60e9ab356b5fdfc1763ad Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 11 May 2026 14:11:32 +0800 Subject: [PATCH 13/24] Revert to 63f62a829: feat: enhance recursive query validation and testing --- datafusion/core/src/physical_planner.rs | 6 +- datafusion/physical-plan/src/common.rs | 330 ++++++++++++++++-- .../physical-plan/src/recursive_query.rs | 321 +++++------------ 3 files changed, 394 insertions(+), 263 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 66f40547ebd4e..3b2c7a78e898e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1777,17 +1777,13 @@ impl DefaultPhysicalPlanner { } } LogicalPlan::RecursiveQuery(RecursiveQuery { - name, - schema, - is_distinct, - .. + name, is_distinct, .. }) => { let [static_term, recursive_term] = children.two()?; Arc::new(RecursiveQueryExec::try_new( name.clone(), static_term, recursive_term, - schema.inner(), *is_distinct, )?) 
} diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 5c29966a1f719..2e5848b1097e2 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -24,15 +24,21 @@ use std::sync::Arc; use super::SendableRecordBatchStream; use crate::expressions::{CastExpr, Column}; use crate::projection::{ProjectionExec, ProjectionExpr}; -use crate::stream::RecordBatchReceiverStream; -use crate::{ColumnStatistics, ExecutionPlan, Statistics}; +use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; +use crate::{ + ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + PlanProperties, Statistics, +}; use arrow::array::Array; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{Result, plan_err}; +use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; use futures::{StreamExt, TryStreamExt}; @@ -90,32 +96,60 @@ fn build_file_list_recurse( Ok(()) } -/// Project `input` to `expected_schema` when [`ProjectionExec`] can produce that exact schema. +/// Align `input`'s physical plan schema with `expected_schema`. /// /// This helper is intended for operators that combine independently planned children but /// expose a single declared output schema. It returns `input` unchanged when schemas already -/// match exactly. Otherwise, it validates that projection can safely produce the expected -/// schema, then wraps `input` in a [`ProjectionExec`] that keeps columns in their existing -/// positional order and aliases them to `expected_schema`'s field names. +/// match exactly. 
Otherwise, it validates positional compatibility and uses a plan-time +/// adapter whose advertised and emitted schema is exactly `expected_schema`. /// -/// [`ProjectionExec`] can rename fields. When the expected field is nullable and the input -/// field is not, this helper also widens nullability with a same-type [`CastExpr`]. It rejects -/// differences that projection cannot safely normalize exactly, such as data type, metadata, -/// schema metadata, and nullability narrowing. -pub fn project_plan_to_schema( +/// Prefer this helper over rebinding batches inside a parent operator's stream. The alignment +/// is visible in the physical plan, while batch schema rebinding remains contained in the +/// adapter as the implementation detail required to uphold the plan-level schema contract. +/// +/// This helper can align field names, nullability, and metadata to the declared schema. It +/// rejects differences that would change values, such as column count or data type mismatches. +/// +/// When an adapter is required, it conservatively derives fresh equivalence properties from +/// `expected_schema` and drops child hash partitioning because field names/nullability may have +/// changed while the underlying partitioning expressions still refer to the child schema. +pub fn align_plan_to_schema( input: Arc, expected_schema: &SchemaRef, ) -> Result> { let input_schema = input.schema(); - if input_schema.fields().len() != expected_schema.fields().len() { - return plan_err!( - "Cannot project plan to expected schema: expected {} column(s), got {}", - expected_schema.fields().len(), - input_schema.fields().len() - ); + if input_schema.as_ref() == expected_schema.as_ref() { + return Ok(input); } + // Projection is the preferred adapter, but not every valid schema-only + // alignment can be represented by ProjectionExec (for example nullability + // narrowing). 
Treat projection errors as path-selection only; if the + // fallback also fails, SchemaAlignExec returns the final diagnostic. + if let Ok(projected) = project_plan_to_schema(Arc::clone(&input), expected_schema) { + debug_assert_eq!(projected.schema().as_ref(), expected_schema.as_ref()); + return Ok(projected); + } + + Ok(Arc::new(SchemaAlignExec::try_new( + input, + Arc::clone(expected_schema), + )?)) +} + +/// Project `input` to `expected_schema` when [`ProjectionExec`] can produce that exact schema. +/// +/// This is a narrower helper than [`align_plan_to_schema`]. It is useful when a positional +/// projection/alias is sufficient. It rejects requests where ProjectionExec cannot advertise the +/// exact expected schema, such as nullability narrowing or metadata changes. +pub fn project_plan_to_schema( + input: Arc, + expected_schema: &SchemaRef, +) -> Result> { + let input_schema = input.schema(); + validate_schema_alignment(&input_schema, expected_schema, "project")?; + if input_schema.as_ref() == expected_schema.as_ref() { return Ok(input); } @@ -132,13 +166,12 @@ pub fn project_plan_to_schema( .zip(expected_schema.fields().iter()) .enumerate() .find_map(|(i, (input_field, expected_field))| { - if input_field.data_type() != expected_field.data_type() { - Some((i, input_field, expected_field, "data type")) - } else if input_field.metadata() != expected_field.metadata() { - Some((i, input_field, expected_field, "metadata")) - } else { - None - } + (input_field.metadata() != expected_field.metadata()).then_some(( + i, + input_field, + expected_field, + "metadata", + )) }) { return plan_err!( @@ -202,6 +235,173 @@ pub fn project_plan_to_schema( Ok(Arc::new(projection)) } +fn validate_schema_alignment( + input_schema: &SchemaRef, + expected_schema: &SchemaRef, + operation: &str, +) -> Result<()> { + if input_schema.fields().len() != expected_schema.fields().len() { + return plan_err!( + "Cannot {operation} plan to expected schema: expected {} column(s), got {}", + 
expected_schema.fields().len(), + input_schema.fields().len() + ); + } + + if let Some((i, input_field, expected_field, mismatch)) = input_schema + .fields() + .iter() + .zip(expected_schema.fields().iter()) + .enumerate() + .find_map(|(i, (input_field, expected_field))| { + if input_field.data_type() != expected_field.data_type() { + Some((i, input_field, expected_field, "data type")) + } else { + None + } + }) + { + return plan_err!( + "Cannot {operation} plan column {i} ('{}') to expected output field '{}': \ + field {mismatch} differs (input field: {:?}, expected field: {:?})", + input_field.name(), + expected_field.name(), + input_field, + expected_field + ); + } + + Ok(()) +} + +/// Plan-time schema adapter for positional schema alignment. +/// +/// [`ProjectionExec`] cannot express every schema-only alignment. In particular, a column +/// expression remains nullable when its input field is nullable, so projection cannot advertise +/// a non-null expected field. This adapter is for cases where the operator-level contract has +/// already established that columns are positionally compatible and the child plan must expose +/// the declared schema exactly. +#[derive(Debug, Clone)] +pub struct SchemaAlignExec { + input: Arc, + schema: SchemaRef, + cache: Arc, +} + +impl SchemaAlignExec { + /// Create a new schema alignment adapter. 
+ pub fn try_new(input: Arc, schema: SchemaRef) -> Result { + validate_schema_alignment(&input.schema(), &schema, "align")?; + + let input_properties = input.properties(); + let partitioning = match &input_properties.partitioning { + Partitioning::RoundRobinBatch(partitions) => { + Partitioning::RoundRobinBatch(*partitions) + } + partitioning => { + Partitioning::UnknownPartitioning(partitioning.partition_count()) + } + }; + let properties = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + partitioning, + input_properties.emission_type, + input_properties.boundedness, + ) + .with_evaluation_type(input_properties.evaluation_type) + .with_scheduling_type(input_properties.scheduling_type); + + Ok(Self { + input, + schema, + cache: Arc::new(properties), + }) + } + + /// Input plan being aligned. + pub fn input(&self) -> &Arc { + &self.input + } +} + +impl DisplayAs for SchemaAlignExec { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "SchemaAlignExec") + } + DisplayFormatType::TreeRender => Ok(()), + } + } +} + +impl ExecutionPlan for SchemaAlignExec { + fn name(&self) -> &'static str { + "SchemaAlignExec" + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn apply_expressions( + &self, + _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, + ) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn maintains_input_order(&self) -> Vec { + vec![true] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + let [input] = children.try_into().map_err(|children: Vec<_>| { + datafusion_common::DataFusionError::Internal(format!( + "SchemaAlignExec expected 1 child, got {}", + children.len() + )) + })?; + Ok(Arc::new(Self::try_new(input, Arc::clone(&self.schema))?)) + } + + fn execute( + &self, + partition: usize, + 
context: Arc, + ) -> Result { + let schema = Arc::clone(&self.schema); + let stream = self.input.execute(partition, context)?.map({ + let schema = Arc::clone(&schema); + move |batch| { + let batch = batch?; + if batch.schema().as_ref() == schema.as_ref() { + Ok(batch) + } else { + RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec()) + .map_err(Into::into) + } + } + }); + Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) + } + + fn partition_statistics(&self, partition: Option) -> Result> { + self.input.partition_statistics(partition) + } +} + /// If running in a tokio context spawns the execution of `stream` to a separate task /// allowing it to execute in parallel with an intermediate buffer of size `buffer` pub fn spawn_buffered( @@ -566,6 +766,88 @@ mod tests { assert!(err.to_string().contains("field nullability differs")); } + #[test] + fn align_plan_to_schema_returns_input_when_schema_matches() -> Result<()> { + let schema = single_field_schema("value", DataType::Int32, false); + let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); + + let result = align_plan_to_schema(Arc::clone(&input), &schema)?; + + assert!(Arc::ptr_eq(&input, &result)); + Ok(()) + } + + #[test] + fn align_plan_to_schema_uses_projection_for_rename_only() -> Result<()> { + let input = single_i32_exec("recursive_a", false); + let expected_schema = single_field_schema("a", DataType::Int32, false); + + let result = align_plan_to_schema(Arc::clone(&input), &expected_schema)?; + + let projection = result + .downcast_ref::() + .expect("rename-only alignment should use ProjectionExec"); + assert!(Arc::ptr_eq(projection.input(), &input)); + assert_eq!(projection.schema(), expected_schema); + Ok(()) + } + + #[test] + fn align_plan_to_schema_uses_adapter_for_nullability_narrowing() -> Result<()> { + let input = single_i32_exec("a", true); + let expected_schema = single_field_schema("renamed", DataType::Int32, false); + + let result = 
align_plan_to_schema(Arc::clone(&input), &expected_schema)?; + + let aligned = result + .downcast_ref::() + .expect("nullability narrowing should use SchemaAlignExec"); + assert!(Arc::ptr_eq(aligned.input(), &input)); + assert_eq!(aligned.schema(), expected_schema); + Ok(()) + } + + #[test] + fn align_plan_to_schema_errors_on_column_count_mismatch() { + let input = single_i32_exec("a", false); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("expected 2 column")); + } + + #[test] + fn align_plan_to_schema_errors_on_type_mismatch() { + let input = single_i32_exec("a", false); + let expected_schema = single_field_schema("a", DataType::Float32, false); + + let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); + assert!(err.to_string().contains("field data type differs")); + } + + #[test] + fn align_plan_to_schema_aligns_field_metadata() -> Result<()> { + let (input, expected_schema) = field_metadata_mismatch(); + + let result = align_plan_to_schema(input, &expected_schema)?; + + assert_eq!(result.schema(), expected_schema); + Ok(()) + } + + #[test] + fn align_plan_to_schema_aligns_schema_metadata() -> Result<()> { + let (input, expected_schema) = schema_metadata_mismatch(); + + let result = align_plan_to_schema(input, &expected_schema)?; + + assert_eq!(result.schema(), expected_schema); + Ok(()) + } + #[test] fn project_plan_to_schema_errors_on_field_metadata_mismatch() { let (input, expected_schema) = field_metadata_mismatch(); diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 5d00e7b350e71..2bb32c39cf3d8 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -24,25 +24,25 @@ use std::task::{Context, Poll}; use 
super::work_table::{ReservedBatches, WorkTable}; use crate::aggregates::group_values::{GroupValues, new_group_values}; use crate::aggregates::order::GroupOrdering; -use crate::common::project_plan_to_schema; +use crate::common::align_plan_to_schema; use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states}; use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, }; -use crate::stream::RecordBatchStreamAdapter; use crate::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; use arrow::array::{BooleanArray, BooleanBuilder}; use arrow::compute::filter_record_batch; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, internal_datafusion_err, not_impl_err}; use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_expr::expr::intersect_metadata_for_union; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; @@ -87,20 +87,21 @@ impl RecursiveQueryExec { name: String, static_term: Arc, recursive_term: Arc, - output_schema: &SchemaRef, is_distinct: bool, ) -> Result { // Each recursive query needs its own work table let work_table = Arc::new(WorkTable::new(name.clone())); - // The logical recursive CTE schema is authoritative. Align both - // children at plan construction time instead of patching batches in - // RecursiveQueryStream. + // Use static term field names with nullability widened across static and + // recursive terms. Align both children at plan construction time instead + // of patching batches in RecursiveQueryStream. 
let recursive_term = assign_work_table(recursive_term, &work_table)?; - let static_term = - align_recursive_child_to_logical_schema(static_term, output_schema)?; - let recursive_term = - align_recursive_child_to_logical_schema(recursive_term, output_schema)?; - let cache = Self::compute_properties(Arc::clone(output_schema)); + let output_schema = recursive_query_output_schema( + &static_term.schema(), + &recursive_term.schema(), + )?; + let static_term = align_plan_to_schema(static_term, &output_schema)?; + let recursive_term = align_plan_to_schema(recursive_term, &output_schema)?; + let cache = Self::compute_properties(Arc::clone(&output_schema)); Ok(RecursiveQueryExec { name, static_term, @@ -190,7 +191,6 @@ impl ExecutionPlan for RecursiveQueryExec { self.name.clone(), Arc::clone(&children[0]), Arc::clone(&children[1]), - &self.schema(), self.is_distinct, ) .map(|e| Arc::new(e) as _) @@ -375,168 +375,52 @@ impl RecursiveQueryStream { } } -fn align_recursive_child_to_logical_schema( - input: Arc, - output_schema: &SchemaRef, -) -> Result> { - match project_plan_to_schema(Arc::clone(&input), output_schema) { - Ok(projected) => Ok(projected), - Err(projection_error) => { - match RecursiveSchemaRebindExec::try_new(input, Arc::clone(output_schema)) { - Ok(exec) => Ok(Arc::new(exec)), - Err(_) => Err(projection_error), - } - } - } -} - -#[derive(Debug, Clone)] -struct RecursiveSchemaRebindExec { - input: Arc, - schema: SchemaRef, - cache: Arc, -} - -impl RecursiveSchemaRebindExec { - fn try_new(input: Arc, schema: SchemaRef) -> Result { - validate_recursive_schema_rebind(&input.schema(), &schema)?; - let input_properties = input.properties(); - let properties = PlanProperties::new( - EquivalenceProperties::new(Arc::clone(&schema)), - Partitioning::UnknownPartitioning( - input_properties.partitioning.partition_count(), - ), - input_properties.emission_type, - input_properties.boundedness, - ) - .with_evaluation_type(input_properties.evaluation_type) - 
.with_scheduling_type(input_properties.scheduling_type); - - Ok(Self { - input, - schema, - cache: Arc::new(properties), - }) - } -} - -impl DisplayAs for RecursiveSchemaRebindExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "RecursiveSchemaRebindExec") - } - DisplayFormatType::TreeRender => Ok(()), - } - } -} - -impl ExecutionPlan for RecursiveSchemaRebindExec { - fn name(&self) -> &'static str { - "RecursiveSchemaRebindExec" - } - - fn properties(&self) -> &Arc { - &self.cache - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn apply_expressions( - &self, - _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, - ) -> Result { - Ok(TreeNodeRecursion::Continue) - } - - fn maintains_input_order(&self) -> Vec { - vec![true] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - let [input] = children.try_into().map_err(|children: Vec<_>| { - internal_datafusion_err!( - "RecursiveSchemaRebindExec expected 1 child, got {}", - children.len() - ) - })?; - Ok(Arc::new(Self::try_new(input, Arc::clone(&self.schema))?)) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - let schema = Arc::clone(&self.schema); - let stream = self.input.execute(partition, context)?.map({ - let schema = Arc::clone(&schema); - move |batch| { - let batch = batch?; - if batch.schema().as_ref() == schema.as_ref() { - Ok(batch) - } else { - RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec()) - .map_err(Into::into) - } - } - }); - Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) - } -} - -fn validate_recursive_schema_rebind( - input_schema: &SchemaRef, - output_schema: &SchemaRef, -) -> Result<()> { - if input_schema.fields().len() != output_schema.fields().len() { +fn recursive_query_output_schema( + static_schema: &SchemaRef, + recursive_schema: 
&SchemaRef, +) -> Result { + if static_schema.fields().len() != recursive_schema.fields().len() { return datafusion_common::plan_err!( - "RecursiveQueryExec input and output schemas have different number of columns: {} != {}", - input_schema.fields().len(), - output_schema.fields().len() + "RecursiveQueryExec static and recursive terms have different number of columns: {} != {}", + static_schema.fields().len(), + recursive_schema.fields().len() ); } - if input_schema.metadata() != output_schema.metadata() { - return datafusion_common::plan_err!( - "Cannot align recursive query input to output schema: schema metadata differ" - ); - } - - if let Some((i, input_field, output_field, mismatch)) = input_schema + let fields = static_schema .fields() .iter() - .zip(output_schema.fields().iter()) + .zip(recursive_schema.fields().iter()) .enumerate() - .find_map(|(i, (input_field, output_field))| { - if input_field.data_type() != output_field.data_type() { - Some((i, input_field, output_field, "type")) - } else if input_field.metadata() != output_field.metadata() { - Some((i, input_field, output_field, "metadata")) - } else { - None + .map(|(i, (static_field, recursive_field))| { + if static_field.data_type() != recursive_field.data_type() { + return datafusion_common::plan_err!( + "RecursiveQueryExec column {i} has different types: static term has {} whereas recursive term has {}", + static_field.data_type(), + recursive_field.data_type() + ); } + Ok(Arc::new( + Field::new( + static_field.name(), + static_field.data_type().clone(), + static_field.is_nullable() || recursive_field.is_nullable(), + ) + .with_metadata(intersect_metadata_for_union([ + static_field.metadata(), + recursive_field.metadata(), + ])), + )) }) - { - return datafusion_common::plan_err!( - "Cannot align recursive query column {i} ('{}') to output field '{}': field {mismatch} differs (input field: {:?}, output field: {:?})", - input_field.name(), - output_field.name(), - input_field, - output_field - ); - 
} - - Ok(()) + .collect::>>()?; + + Ok(Arc::new(Schema::new_with_metadata( + fields, + intersect_metadata_for_union([ + static_schema.metadata(), + recursive_schema.metadata(), + ]), + ))) } fn assign_work_table( @@ -682,13 +566,11 @@ mod tests { fn recursive_exec( static_term: Arc, recursive_term: Arc, - output_schema: &SchemaRef, ) -> Result { RecursiveQueryExec::try_new( "numbers".to_string(), static_term, recursive_term, - output_schema, false, ) } @@ -699,11 +581,7 @@ mod tests { let recursive_term = empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, false)]); - let exec = recursive_exec( - Arc::clone(&static_term), - Arc::clone(&recursive_term), - &static_term.schema(), - )?; + let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?; assert_eq!(exec.schema(), static_term.schema()); let projection = exec @@ -723,16 +601,13 @@ mod tests { let recursive_term = empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]); + let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?; + let expected_schema = Arc::new(Schema::new(vec![Field::new( "value", DataType::Int32, true, )])); - let exec = recursive_exec( - Arc::clone(&static_term), - Arc::clone(&recursive_term), - &expected_schema, - )?; assert_eq!(exec.schema(), expected_schema); assert_eq!(exec.static_term().schema(), expected_schema); assert_eq!(exec.recursive_term().schema(), expected_schema); @@ -741,67 +616,45 @@ mod tests { } #[test] - fn recursive_query_exec_uses_rebind_for_nullability_narrowing() -> Result<()> { - let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); - let recursive_term = empty_exec(vec![Field::new("value", DataType::Int32, true)]); - let output_schema = static_term.schema(); - - let exec = recursive_exec( - Arc::clone(&static_term), - Arc::clone(&recursive_term), - &output_schema, - )?; - - assert_eq!(exec.schema(), output_schema); - assert_eq!(exec.recursive_term().schema(), 
output_schema); - assert!( - exec.recursive_term() - .downcast_ref::() - .is_some() - ); - Ok(()) - } - - #[test] - fn recursive_query_exec_rejects_field_metadata_mismatch() { - let input_metadata = HashMap::from([("source".to_string(), "input".to_string())]); - let output_metadata = - HashMap::from([("source".to_string(), "output".to_string())]); - let static_schema = Arc::new(Schema::new(vec![ - Field::new("value", DataType::Int32, false).with_metadata(input_metadata), - ])); - let output_schema = Arc::new(Schema::new(vec![ - Field::new("value", DataType::Int32, false).with_metadata(output_metadata), - ])); - - let err = recursive_exec( - empty_exec_with_schema(static_schema), - empty_exec(vec![Field::new("value", DataType::Int32, false)]), - &output_schema, - ) - .unwrap_err(); - - assert!(err.to_string().contains("field metadata differs")); - } - - #[test] - fn recursive_query_exec_rejects_schema_metadata_mismatch() { + fn recursive_query_exec_intersects_output_metadata() -> Result<()> { + let static_field = + Field::new("value", DataType::Int32, false).with_metadata(HashMap::from([ + ("shared".to_string(), "same".to_string()), + ("static".to_string(), "only".to_string()), + ])); + let recursive_field = + Field::new("value", DataType::Int32, false).with_metadata(HashMap::from([ + ("shared".to_string(), "same".to_string()), + ("static".to_string(), "different".to_string()), + ])); let static_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("value", DataType::Int32, false)], - HashMap::from([("source".to_string(), "input".to_string())]), + vec![static_field], + HashMap::from([ + ("shared".to_string(), "same".to_string()), + ("static".to_string(), "only".to_string()), + ]), )); - let output_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("value", DataType::Int32, false)], - HashMap::from([("source".to_string(), "output".to_string())]), + let recursive_schema = Arc::new(Schema::new_with_metadata( + vec![recursive_field], + 
HashMap::from([ + ("shared".to_string(), "same".to_string()), + ("static".to_string(), "different".to_string()), + ]), )); - let err = recursive_exec( + let exec = recursive_exec( empty_exec_with_schema(static_schema), - empty_exec(vec![Field::new("value", DataType::Int32, false)]), - &output_schema, - ) - .unwrap_err(); + empty_exec_with_schema(recursive_schema), + )?; - assert!(err.to_string().contains("schema metadata differ")); + assert_eq!( + exec.schema().field(0).metadata(), + &HashMap::from([("shared".to_string(), "same".to_string())]) + ); + assert_eq!( + exec.schema().metadata(), + &HashMap::from([("shared".to_string(), "same".to_string())]) + ); + Ok(()) } } From 973f93e094c56248b8876f7b19b1e113ae7e0cf1 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 11 May 2026 15:47:45 +0800 Subject: [PATCH 14/24] feat: enhance recursive query handling by aligning schemas and preserving metadata --- datafusion/expr/src/logical_plan/builder.rs | 34 ++ datafusion/expr/src/logical_plan/plan.rs | 71 ++- datafusion/physical-plan/src/common.rs | 435 +++--------------- .../physical-plan/src/recursive_query.rs | 149 +++++- 4 files changed, 279 insertions(+), 410 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 888b252d2cbcc..4c97c2e9b5d1b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -2291,6 +2291,7 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { + use std::collections::HashMap; use std::vec; use super::*; @@ -2386,6 +2387,39 @@ mod tests { Ok(()) } + #[test] + fn recursive_query_schema_preserves_static_metadata() -> Result<()> { + let static_metadata = + HashMap::from([("source".to_string(), "static".to_string())]); + let recursive_metadata = + HashMap::from([("source".to_string(), "recursive".to_string())]); + let static_schema = Schema::new_with_metadata( + vec![ + Field::new("n", DataType::Int32, false) + 
.with_metadata(static_metadata.clone()), + ], + static_metadata.clone(), + ); + let recursive_schema = Schema::new_with_metadata( + vec![ + Field::new("recursive_n", DataType::Int32, false) + .with_metadata(recursive_metadata), + ], + HashMap::from([("source".to_string(), "recursive".to_string())]), + ); + + let static_term = table_scan(Some("static_t"), &static_schema, None)?; + let recursive_term = + table_scan(Some("recursive_t"), &recursive_schema, None)?.build()?; + let plan = static_term + .to_recursive_query("t".to_string(), recursive_term, false)? + .build()?; + + assert_eq!(plan.schema().field(0).metadata(), &static_metadata); + assert_eq!(plan.schema().metadata(), &static_metadata); + Ok(()) + } + #[test] fn plan_builder_union() -> Result<()> { let plan = diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9960de040099a..d3ba6635910c6 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -353,7 +353,9 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { output_schema, .. }) => output_schema, LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, - LogicalPlan::RecursiveQuery(RecursiveQuery { schema, .. }) => schema, + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + static_term.schema() + } } } @@ -2252,9 +2254,6 @@ pub struct RecursiveQuery { /// The recursive term (evaluated on the contents of the working table until /// it returns an empty set) pub recursive_term: Arc, - /// Output schema, using static term field names and nullability widened - /// across both static and recursive terms. - pub schema: DFSchemaRef, /// Should the output of the recursive term be deduplicated (`UNION`) or /// not (`UNION ALL`). 
pub is_distinct: bool, @@ -2271,16 +2270,36 @@ impl RecursiveQuery { ) -> Result { let schema = recursive_query_schema(static_term.schema(), recursive_term.schema())?; + let static_term = align_logical_plan_to_schema(static_term, schema)?; Ok(Self { name, static_term, recursive_term, - schema, is_distinct, }) } } +fn align_logical_plan_to_schema( + input: Arc, + schema: DFSchemaRef, +) -> Result> { + if input.schema().as_ref() == schema.as_ref() { + return Ok(input); + } + + let expr = input + .schema() + .fields() + .iter() + .enumerate() + .map(|(i, _)| Expr::Column(Column::from(input.schema().qualified_field(i)))) + .collect(); + Ok(Arc::new(LogicalPlan::Projection( + Projection::try_new_with_schema(expr, input, schema)?, + ))) +} + fn recursive_query_schema( static_schema: &DFSchema, recursive_schema: &DFSchema, @@ -2312,22 +2331,17 @@ fn recursive_query_schema( static_field.data_type().clone(), static_field.is_nullable() || recursive_field.is_nullable(), ) - .with_metadata(intersect_metadata_for_union([ - static_field.metadata(), - recursive_field.metadata(), - ])); + .with_metadata(static_field.metadata().clone()); Ok((qualifier.cloned(), Arc::new(field))) }) .collect::>>()?; - let metadata = intersect_metadata_for_union([ - static_schema.metadata(), - recursive_schema.metadata(), - ]); - Ok(Arc::new(DFSchema::new_with_metadata(fields, metadata)?)) + Ok(Arc::new(DFSchema::new_with_metadata( + fields, + static_schema.metadata().clone(), + )?)) } -// Manual implementation needed because of `schema` field. Comparison excludes this field. 
impl PartialOrd for RecursiveQuery { fn partial_cmp(&self, other: &Self) -> Option { ( @@ -4966,6 +4980,33 @@ mod tests { })) } + #[test] + fn recursive_query_try_new_aligns_static_term_to_widened_schema() -> Result<()> { + let static_term = + empty_plan_with_fields(vec![Field::new("a", DataType::Int32, false)]); + let recursive_term = + empty_plan_with_fields(vec![Field::new("b", DataType::Int32, true)]); + + let query = RecursiveQuery::try_new( + "t".to_string(), + Arc::clone(&static_term), + Arc::clone(&recursive_term), + false, + )?; + + assert_eq!(query.static_term.schema().field(0).name(), "a"); + assert!(query.static_term.schema().field(0).is_nullable()); + assert!(matches!( + query.static_term.as_ref(), + LogicalPlan::Projection(_) + )); + assert!( + Arc::ptr_eq(&query.recursive_term, &recursive_term), + "recursive term should not be wrapped in a schema-only Projection" + ); + Ok(()) + } + #[test] fn recursive_query_try_new_rejects_mismatched_column_count() { let static_term = diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index 2e5848b1097e2..0dafcf6bd3390 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -24,21 +24,15 @@ use std::sync::Arc; use super::SendableRecordBatchStream; use crate::expressions::{CastExpr, Column}; use crate::projection::{ProjectionExec, ProjectionExpr}; -use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; -use crate::{ - ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - PlanProperties, Statistics, -}; +use crate::stream::RecordBatchReceiverStream; +use crate::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::array::Array; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{Result, plan_err}; -use datafusion_execution::TaskContext; use 
datafusion_execution::memory_pool::MemoryReservation; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; use futures::{StreamExt, TryStreamExt}; @@ -100,58 +94,29 @@ fn build_file_list_recurse( /// /// This helper is intended for operators that combine independently planned children but /// expose a single declared output schema. It returns `input` unchanged when schemas already -/// match exactly. Otherwise, it validates positional compatibility and uses a plan-time -/// adapter whose advertised and emitted schema is exactly `expected_schema`. +/// match exactly. Otherwise, it validates that projection can safely produce the expected +/// schema, then wraps `input` in a [`ProjectionExec`] that keeps columns in their existing +/// positional order and aliases them to `expected_schema`'s field names. /// -/// Prefer this helper over rebinding batches inside a parent operator's stream. The alignment -/// is visible in the physical plan, while batch schema rebinding remains contained in the -/// adapter as the implementation detail required to uphold the plan-level schema contract. -/// -/// This helper can align field names, nullability, and metadata to the declared schema. It -/// rejects differences that would change values, such as column count or data type mismatches. -/// -/// When an adapter is required, it conservatively derives fresh equivalence properties from -/// `expected_schema` and drops child hash partitioning because field names/nullability may have -/// changed while the underlying partitioning expressions still refer to the child schema. -pub fn align_plan_to_schema( +/// [`ProjectionExec`] can rename fields. When the expected field is nullable and the input +/// field is not, this helper also widens nullability with a same-type [`CastExpr`]. It rejects +/// differences that projection cannot safely normalize exactly, such as data type, metadata, +/// schema metadata, and nullability narrowing. 
+pub fn project_plan_to_schema( input: Arc, expected_schema: &SchemaRef, ) -> Result> { let input_schema = input.schema(); - if input_schema.as_ref() == expected_schema.as_ref() { return Ok(input); } - // Projection is the preferred adapter, but not every valid schema-only - // alignment can be represented by ProjectionExec (for example nullability - // narrowing). Treat projection errors as path-selection only; if the - // fallback also fails, SchemaAlignExec returns the final diagnostic. - if let Ok(projected) = project_plan_to_schema(Arc::clone(&input), expected_schema) { - debug_assert_eq!(projected.schema().as_ref(), expected_schema.as_ref()); - return Ok(projected); - } - - Ok(Arc::new(SchemaAlignExec::try_new( - input, - Arc::clone(expected_schema), - )?)) -} - -/// Project `input` to `expected_schema` when [`ProjectionExec`] can produce that exact schema. -/// -/// This is a narrower helper than [`align_plan_to_schema`]. It is useful when a positional -/// projection/alias is sufficient. It rejects requests where ProjectionExec cannot advertise the -/// exact expected schema, such as nullability narrowing or metadata changes. 
-pub fn project_plan_to_schema( - input: Arc, - expected_schema: &SchemaRef, -) -> Result> { - let input_schema = input.schema(); - validate_schema_alignment(&input_schema, expected_schema, "project")?; - - if input_schema.as_ref() == expected_schema.as_ref() { - return Ok(input); + if input_schema.fields().len() != expected_schema.fields().len() { + return plan_err!( + "Cannot project plan to expected schema: expected {} column(s), got {}", + expected_schema.fields().len(), + input_schema.fields().len() + ); } if input_schema.metadata() != expected_schema.metadata() { @@ -166,12 +131,15 @@ pub fn project_plan_to_schema( .zip(expected_schema.fields().iter()) .enumerate() .find_map(|(i, (input_field, expected_field))| { - (input_field.metadata() != expected_field.metadata()).then_some(( - i, - input_field, - expected_field, - "metadata", - )) + if input_field.data_type() != expected_field.data_type() { + Some((i, input_field, expected_field, "data type")) + } else if input_field.is_nullable() && !expected_field.is_nullable() { + Some((i, input_field, expected_field, "nullability")) + } else if input_field.metadata() != expected_field.metadata() { + Some((i, input_field, expected_field, "metadata")) + } else { + None + } }) { return plan_err!( @@ -184,29 +152,6 @@ pub fn project_plan_to_schema( ); } - if let Some((i, input_field, expected_field)) = input_schema - .fields() - .iter() - .zip(expected_schema.fields().iter()) - .enumerate() - .find_map(|(i, (input_field, expected_field))| { - (input_field.is_nullable() && !expected_field.is_nullable()).then_some(( - i, - input_field, - expected_field, - )) - }) - { - return plan_err!( - "Cannot project plan column {i} ('{}') to expected output field '{}': \ - field nullability differs (input field: {:?}, expected field: {:?})", - input_field.name(), - expected_field.name(), - input_field, - expected_field - ); - } - let projection_exprs = expected_schema .fields() .iter() @@ -235,173 +180,6 @@ pub fn 
project_plan_to_schema( Ok(Arc::new(projection)) } -fn validate_schema_alignment( - input_schema: &SchemaRef, - expected_schema: &SchemaRef, - operation: &str, -) -> Result<()> { - if input_schema.fields().len() != expected_schema.fields().len() { - return plan_err!( - "Cannot {operation} plan to expected schema: expected {} column(s), got {}", - expected_schema.fields().len(), - input_schema.fields().len() - ); - } - - if let Some((i, input_field, expected_field, mismatch)) = input_schema - .fields() - .iter() - .zip(expected_schema.fields().iter()) - .enumerate() - .find_map(|(i, (input_field, expected_field))| { - if input_field.data_type() != expected_field.data_type() { - Some((i, input_field, expected_field, "data type")) - } else { - None - } - }) - { - return plan_err!( - "Cannot {operation} plan column {i} ('{}') to expected output field '{}': \ - field {mismatch} differs (input field: {:?}, expected field: {:?})", - input_field.name(), - expected_field.name(), - input_field, - expected_field - ); - } - - Ok(()) -} - -/// Plan-time schema adapter for positional schema alignment. -/// -/// [`ProjectionExec`] cannot express every schema-only alignment. In particular, a column -/// expression remains nullable when its input field is nullable, so projection cannot advertise -/// a non-null expected field. This adapter is for cases where the operator-level contract has -/// already established that columns are positionally compatible and the child plan must expose -/// the declared schema exactly. -#[derive(Debug, Clone)] -pub struct SchemaAlignExec { - input: Arc, - schema: SchemaRef, - cache: Arc, -} - -impl SchemaAlignExec { - /// Create a new schema alignment adapter. 
- pub fn try_new(input: Arc, schema: SchemaRef) -> Result { - validate_schema_alignment(&input.schema(), &schema, "align")?; - - let input_properties = input.properties(); - let partitioning = match &input_properties.partitioning { - Partitioning::RoundRobinBatch(partitions) => { - Partitioning::RoundRobinBatch(*partitions) - } - partitioning => { - Partitioning::UnknownPartitioning(partitioning.partition_count()) - } - }; - let properties = PlanProperties::new( - EquivalenceProperties::new(Arc::clone(&schema)), - partitioning, - input_properties.emission_type, - input_properties.boundedness, - ) - .with_evaluation_type(input_properties.evaluation_type) - .with_scheduling_type(input_properties.scheduling_type); - - Ok(Self { - input, - schema, - cache: Arc::new(properties), - }) - } - - /// Input plan being aligned. - pub fn input(&self) -> &Arc { - &self.input - } -} - -impl DisplayAs for SchemaAlignExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "SchemaAlignExec") - } - DisplayFormatType::TreeRender => Ok(()), - } - } -} - -impl ExecutionPlan for SchemaAlignExec { - fn name(&self) -> &'static str { - "SchemaAlignExec" - } - - fn properties(&self) -> &Arc { - &self.cache - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn apply_expressions( - &self, - _f: &mut dyn FnMut(&dyn PhysicalExpr) -> Result, - ) -> Result { - Ok(TreeNodeRecursion::Continue) - } - - fn maintains_input_order(&self) -> Vec { - vec![true] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - let [input] = children.try_into().map_err(|children: Vec<_>| { - datafusion_common::DataFusionError::Internal(format!( - "SchemaAlignExec expected 1 child, got {}", - children.len() - )) - })?; - Ok(Arc::new(Self::try_new(input, Arc::clone(&self.schema))?)) - } - - fn execute( - &self, - partition: usize, - 
context: Arc, - ) -> Result { - let schema = Arc::clone(&self.schema); - let stream = self.input.execute(partition, context)?.map({ - let schema = Arc::clone(&schema); - move |batch| { - let batch = batch?; - if batch.schema().as_ref() == schema.as_ref() { - Ok(batch) - } else { - RecordBatch::try_new(Arc::clone(&schema), batch.columns().to_vec()) - .map_err(Into::into) - } - } - }); - Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream))) - } - - fn partition_statistics(&self, partition: Option) -> Result> { - self.input.partition_statistics(partition) - } -} - /// If running in a tokio context spawns the execution of `stream` to a separate task /// allowing it to execute in parallel with an intermediate buffer of size `buffer` pub fn spawn_buffered( @@ -531,40 +309,6 @@ mod tests { Arc::new(EmptyExec::new(Arc::new(Schema::new(fields)))) } - fn single_field_schema(name: &str, data_type: DataType, nullable: bool) -> SchemaRef { - Arc::new(Schema::new(vec![Field::new(name, data_type, nullable)])) - } - - fn single_i32_exec(name: &str, nullable: bool) -> Arc { - empty_exec(vec![Field::new(name, DataType::Int32, nullable)]) - } - - fn field_metadata_mismatch() -> (Arc, SchemaRef) { - let input = - empty_exec(vec![Field::new("a", DataType::Int32, false).with_metadata( - HashMap::from([("source".to_string(), "input".to_string())]), - )]); - let expected_schema = Arc::new(Schema::new(vec![ - Field::new("renamed", DataType::Int32, false).with_metadata(HashMap::from([ - ("source".to_string(), "expected".to_string()), - ])), - ])); - (input, expected_schema) - } - - fn schema_metadata_mismatch() -> (Arc, SchemaRef) { - let input_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("a", DataType::Int32, false)], - HashMap::from([("source".to_string(), "input".to_string())]), - )); - let input: Arc = Arc::new(EmptyExec::new(input_schema)); - let expected_schema = Arc::new(Schema::new_with_metadata( - vec![Field::new("renamed", DataType::Int32, false)], - 
HashMap::from([("source".to_string(), "expected".to_string())]), - )); - (input, expected_schema) - } - #[test] fn test_compute_record_batch_statistics_empty() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -666,7 +410,11 @@ mod tests { #[test] fn project_plan_to_schema_returns_input_when_schema_matches() -> Result<()> { - let schema = single_field_schema("value", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int32, + false, + )])); let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); let result = project_plan_to_schema(Arc::clone(&input), &schema)?; @@ -727,7 +475,7 @@ mod tests { #[test] fn project_plan_to_schema_errors_on_column_count_mismatch() { - let input = single_i32_exec("a", false); + let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); let expected_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -739,8 +487,9 @@ mod tests { #[test] fn project_plan_to_schema_errors_on_type_mismatch() { - let input = single_i32_exec("a", false); - let expected_schema = single_field_schema("a", DataType::Float32, false); + let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); + let expected_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field data type differs")); @@ -748,8 +497,12 @@ mod tests { #[test] fn project_plan_to_schema_widens_nullability() -> Result<()> { - let input = single_i32_exec("a", false); - let expected_schema = single_field_schema("renamed", DataType::Int32, true); + let input = empty_exec(vec![Field::new("a", DataType::Int32, false)]); + let expected_schema = Arc::new(Schema::new(vec![Field::new( + "renamed", + DataType::Int32, + true, + )])); let result = project_plan_to_schema(input, &expected_schema)?; @@ -759,106 +512,44 @@ 
mod tests { #[test] fn project_plan_to_schema_errors_on_nullability_narrowing() { - let input = single_i32_exec("a", true); - let expected_schema = single_field_schema("renamed", DataType::Int32, false); + let input = empty_exec(vec![Field::new("a", DataType::Int32, true)]); + let expected_schema = Arc::new(Schema::new(vec![Field::new( + "renamed", + DataType::Int32, + false, + )])); let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field nullability differs")); } #[test] - fn align_plan_to_schema_returns_input_when_schema_matches() -> Result<()> { - let schema = single_field_schema("value", DataType::Int32, false); - let input: Arc = Arc::new(EmptyExec::new(Arc::clone(&schema))); - - let result = align_plan_to_schema(Arc::clone(&input), &schema)?; - - assert!(Arc::ptr_eq(&input, &result)); - Ok(()) - } - - #[test] - fn align_plan_to_schema_uses_projection_for_rename_only() -> Result<()> { - let input = single_i32_exec("recursive_a", false); - let expected_schema = single_field_schema("a", DataType::Int32, false); - - let result = align_plan_to_schema(Arc::clone(&input), &expected_schema)?; - - let projection = result - .downcast_ref::() - .expect("rename-only alignment should use ProjectionExec"); - assert!(Arc::ptr_eq(projection.input(), &input)); - assert_eq!(projection.schema(), expected_schema); - Ok(()) - } - - #[test] - fn align_plan_to_schema_uses_adapter_for_nullability_narrowing() -> Result<()> { - let input = single_i32_exec("a", true); - let expected_schema = single_field_schema("renamed", DataType::Int32, false); - - let result = align_plan_to_schema(Arc::clone(&input), &expected_schema)?; - - let aligned = result - .downcast_ref::() - .expect("nullability narrowing should use SchemaAlignExec"); - assert!(Arc::ptr_eq(aligned.input(), &input)); - assert_eq!(aligned.schema(), expected_schema); - Ok(()) - } - - #[test] - fn align_plan_to_schema_errors_on_column_count_mismatch() { - let input = 
single_i32_exec("a", false); + fn project_plan_to_schema_errors_on_field_metadata_mismatch() { + let input = + empty_exec(vec![Field::new("a", DataType::Int32, false).with_metadata( + HashMap::from([("source".to_string(), "input".to_string())]), + )]); let expected_schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), + Field::new("renamed", DataType::Int32, false).with_metadata(HashMap::from([ + ("source".to_string(), "expected".to_string()), + ])), ])); - let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); - assert!(err.to_string().contains("expected 2 column")); - } - - #[test] - fn align_plan_to_schema_errors_on_type_mismatch() { - let input = single_i32_exec("a", false); - let expected_schema = single_field_schema("a", DataType::Float32, false); - - let err = align_plan_to_schema(input, &expected_schema).unwrap_err(); - assert!(err.to_string().contains("field data type differs")); - } - - #[test] - fn align_plan_to_schema_aligns_field_metadata() -> Result<()> { - let (input, expected_schema) = field_metadata_mismatch(); - - let result = align_plan_to_schema(input, &expected_schema)?; - - assert_eq!(result.schema(), expected_schema); - Ok(()) - } - - #[test] - fn align_plan_to_schema_aligns_schema_metadata() -> Result<()> { - let (input, expected_schema) = schema_metadata_mismatch(); - - let result = align_plan_to_schema(input, &expected_schema)?; - - assert_eq!(result.schema(), expected_schema); - Ok(()) - } - - #[test] - fn project_plan_to_schema_errors_on_field_metadata_mismatch() { - let (input, expected_schema) = field_metadata_mismatch(); - let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("field metadata differs")); } #[test] fn project_plan_to_schema_errors_on_schema_metadata_mismatch() { - let (input, expected_schema) = schema_metadata_mismatch(); + let input_schema = Arc::new(Schema::new_with_metadata( + 
vec![Field::new("a", DataType::Int32, false)], + HashMap::from([("source".to_string(), "input".to_string())]), + )); + let input: Arc = Arc::new(EmptyExec::new(input_schema)); + let expected_schema = Arc::new(Schema::new_with_metadata( + vec![Field::new("renamed", DataType::Int32, false)], + HashMap::from([("source".to_string(), "expected".to_string())]), + )); let err = project_plan_to_schema(input, &expected_schema).unwrap_err(); assert!(err.to_string().contains("schema metadata differ")); diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 2bb32c39cf3d8..43c0d6d068f58 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -24,7 +24,7 @@ use std::task::{Context, Poll}; use super::work_table::{ReservedBatches, WorkTable}; use crate::aggregates::group_values::{GroupValues, new_group_values}; use crate::aggregates::order::GroupOrdering; -use crate::common::align_plan_to_schema; +use crate::common::project_plan_to_schema; use crate::execution_plan::{Boundedness, EmissionType, reset_plan_states}; use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, @@ -42,7 +42,6 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, internal_datafusion_err, not_impl_err}; use datafusion_execution::TaskContext; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; -use datafusion_expr::expr::intersect_metadata_for_union; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; @@ -82,7 +81,13 @@ pub struct RecursiveQueryExec { } impl RecursiveQueryExec { - /// Create a new RecursiveQueryExec + /// Create a new [`RecursiveQueryExec`] deriving the output schema from the + /// physical children. + /// + /// This constructor is retained for backward compatibility. 
Planner-created + /// recursive CTEs should use [`Self::try_new_with_schema`] with the logical + /// recursive CTE schema, which can widen nullability across the static and + /// recursive terms. pub fn try_new( name: String, static_term: Arc, @@ -91,17 +96,68 @@ impl RecursiveQueryExec { ) -> Result { // Each recursive query needs its own work table let work_table = Arc::new(WorkTable::new(name.clone())); - // Use static term field names with nullability widened across static and - // recursive terms. Align both children at plan construction time instead - // of patching batches in RecursiveQueryStream. + // Preserve the legacy constructor behavior by deriving the output schema + // from the physical children. let recursive_term = assign_work_table(recursive_term, &work_table)?; let output_schema = recursive_query_output_schema( &static_term.schema(), &recursive_term.schema(), )?; - let static_term = align_plan_to_schema(static_term, &output_schema)?; - let recursive_term = align_plan_to_schema(recursive_term, &output_schema)?; - let cache = Self::compute_properties(Arc::clone(&output_schema)); + let static_term = project_plan_to_schema(static_term, &output_schema)?; + let recursive_term = project_plan_to_schema(recursive_term, &output_schema)?; + Self::try_new_with_work_table( + name, + work_table, + static_term, + recursive_term, + &output_schema, + is_distinct, + ) + } + + /// Create a new [`RecursiveQueryExec`] with an explicit output schema. + /// + /// The supplied `output_schema` is authoritative. Both the static term and + /// recursive term are aligned to this schema at plan construction time. Use + /// this constructor when the logical recursive CTE schema is known. + /// + /// Recursive CTE schema contract: + /// + /// * field names come from the static term; + /// * data types must be compatible across static and recursive terms; + /// * nullability is widened across both terms; + /// * metadata must remain consistent with the logical schema. 
+ pub fn try_new_with_schema( + name: String, + static_term: Arc, + recursive_term: Arc, + output_schema: &SchemaRef, + is_distinct: bool, + ) -> Result { + // Each recursive query needs its own work table + let work_table = Arc::new(WorkTable::new(name.clone())); + let recursive_term = assign_work_table(recursive_term, &work_table)?; + let static_term = project_plan_to_schema(static_term, output_schema)?; + let recursive_term = project_plan_to_schema(recursive_term, output_schema)?; + Self::try_new_with_work_table( + name, + work_table, + static_term, + recursive_term, + output_schema, + is_distinct, + ) + } + + fn try_new_with_work_table( + name: String, + work_table: Arc, + static_term: Arc, + recursive_term: Arc, + output_schema: &SchemaRef, + is_distinct: bool, + ) -> Result { + let cache = Self::compute_properties(Arc::clone(output_schema)); Ok(RecursiveQueryExec { name, static_term, @@ -187,10 +243,12 @@ impl ExecutionPlan for RecursiveQueryExec { self: Arc, children: Vec>, ) -> Result> { - RecursiveQueryExec::try_new( + let output_schema = self.schema(); + RecursiveQueryExec::try_new_with_schema( self.name.clone(), Arc::clone(&children[0]), Arc::clone(&children[1]), + &output_schema, self.is_distinct, ) .map(|e| Arc::new(e) as _) @@ -406,20 +464,14 @@ fn recursive_query_output_schema( static_field.data_type().clone(), static_field.is_nullable() || recursive_field.is_nullable(), ) - .with_metadata(intersect_metadata_for_union([ - static_field.metadata(), - recursive_field.metadata(), - ])), + .with_metadata(static_field.metadata().clone()), )) }) .collect::>>()?; Ok(Arc::new(Schema::new_with_metadata( fields, - intersect_metadata_for_union([ - static_schema.metadata(), - recursive_schema.metadata(), - ]), + static_schema.metadata().clone(), ))) } @@ -575,6 +627,23 @@ mod tests { ) } + #[test] + fn recursive_query_exec_try_new_keeps_backward_compatible_default() -> Result<()> { + let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); 
+ let recursive_term = + empty_exec(vec![Field::new("value", DataType::Int32, false)]); + + let exec = RecursiveQueryExec::try_new( + "numbers".to_string(), + Arc::clone(&static_term), + recursive_term, + false, + )?; + + assert_eq!(exec.schema(), static_term.schema()); + Ok(()) + } + #[test] fn recursive_query_exec_projects_recursive_term_to_reconciled_schema() -> Result<()> { let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); @@ -616,7 +685,35 @@ mod tests { } #[test] - fn recursive_query_exec_intersects_output_metadata() -> Result<()> { + fn recursive_query_exec_with_schema_uses_declared_output_schema() -> Result<()> { + let static_term = empty_exec(vec![Field::new("anchor", DataType::Int32, false)]); + let recursive_term = empty_exec(vec![Field::new( + "anchor + Int32(1)", + DataType::Int32, + false, + )]); + let output_schema = Arc::new(Schema::new(vec![Field::new( + "declared", + DataType::Int32, + true, + )])); + + let exec = RecursiveQueryExec::try_new_with_schema( + "numbers".to_string(), + static_term, + recursive_term, + &output_schema, + false, + )?; + + assert_eq!(exec.schema(), output_schema); + assert_eq!(exec.static_term().schema(), output_schema); + assert_eq!(exec.recursive_term().schema(), output_schema); + Ok(()) + } + + #[test] + fn recursive_query_exec_preserves_static_output_metadata() -> Result<()> { let static_field = Field::new("value", DataType::Int32, false).with_metadata(HashMap::from([ ("shared".to_string(), "same".to_string()), @@ -625,7 +722,7 @@ mod tests { let recursive_field = Field::new("value", DataType::Int32, false).with_metadata(HashMap::from([ ("shared".to_string(), "same".to_string()), - ("static".to_string(), "different".to_string()), + ("static".to_string(), "only".to_string()), ])); let static_schema = Arc::new(Schema::new_with_metadata( vec![static_field], @@ -638,7 +735,7 @@ mod tests { vec![recursive_field], HashMap::from([ ("shared".to_string(), "same".to_string()), - 
("static".to_string(), "different".to_string()), + ("static".to_string(), "only".to_string()), ]), )); @@ -649,11 +746,17 @@ mod tests { assert_eq!( exec.schema().field(0).metadata(), - &HashMap::from([("shared".to_string(), "same".to_string())]) + &HashMap::from([ + ("shared".to_string(), "same".to_string()), + ("static".to_string(), "only".to_string()), + ]) ); assert_eq!( exec.schema().metadata(), - &HashMap::from([("shared".to_string(), "same".to_string())]) + &HashMap::from([ + ("shared".to_string(), "same".to_string()), + ("static".to_string(), "only".to_string()), + ]) ); Ok(()) } From ec9aafd52af43d73a84bbf23c419a26b06e97803 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Mon, 11 May 2026 16:28:15 +0800 Subject: [PATCH 15/24] feat: enhance documentation for name preservation and nullability rationale - Added two comments in `plan.rs` to clarify the name-preservation invariant and nullability-widening rationale at the construction site. - Updated documentation in `recursive_query.rs` to note that `output_schema` is pre-widened, ensuring safe direction for recursive CTEs. - Introduced a new query in `cte.slt` to test distinct column aliases, reinforcing the invariant that the CTE's exposed column name comes from the anchor term. --- datafusion/expr/src/logical_plan/plan.rs | 5 +++++ datafusion/physical-plan/src/recursive_query.rs | 5 +++++ datafusion/sqllogictest/test_files/cte.slt | 14 ++++++++++++++ 3 files changed, 24 insertions(+) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index d3ba6635910c6..930ab47444ef7 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2325,10 +2325,15 @@ fn recursive_query_schema( recursive_field.data_type() ); } + // Field names and qualifiers always come from the static/anchor term so + // that recursive-term column names never leak into the declared CTE schema. 
let (qualifier, _) = static_schema.qualified_field(i); let field = Field::new( static_field.name(), static_field.data_type().clone(), + // Nullability is widened (union-like) across both terms so that a + // nullable recursive expression does not force a runtime error when + // the anchor is non-nullable (e.g. `SELECT 0 AS level`). static_field.is_nullable() || recursive_field.is_nullable(), ) .with_metadata(static_field.metadata().clone()); diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 43c0d6d068f58..a0c5384d7f068 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -137,6 +137,11 @@ impl RecursiveQueryExec { // Each recursive query needs its own work table let work_table = Arc::new(WorkTable::new(name.clone())); let recursive_term = assign_work_table(recursive_term, &work_table)?; + // `output_schema` has already been widened for nullability at logical + // planning time, so `project_plan_to_schema` here only ever widens + // non-nullable → nullable (safe cast) or encounters equal-nullability + // fields. It will never be asked to narrow a nullable input to a + // non-null expected field for a recursive CTE child. let static_term = project_plan_to_schema(static_term, output_schema)?; let recursive_term = project_plan_to_schema(recursive_term, output_schema)?; Self::try_new_with_work_table( diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index e16f96cdd44b7..fd19e03ef4374 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -1313,6 +1313,20 @@ SELECT * FROM t; 0 NULL +# Column names in a recursive CTE must come from the anchor/static term, +# not from the recursive term, even when the recursive term uses a different alias. 
+query I +WITH RECURSIVE t AS ( + SELECT 0 AS anchor_col + UNION ALL + SELECT anchor_col + 1 AS recursive_col FROM t WHERE anchor_col < 2 +) +SELECT anchor_col FROM t; +---- +0 +1 +2 + statement count 0 set datafusion.execution.enable_recursive_ctes = false; From 59e8d922ac6d7dc6a67169b10e8e5d9642b08c0e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 13 May 2026 21:32:11 +0800 Subject: [PATCH 16/24] Revert to 739e1471b: Add reusable plan-time schema alignment helper and apply to RecursiveQueryExec (#21912) --- datafusion/expr/src/logical_plan/builder.rs | 65 +---- datafusion/expr/src/logical_plan/plan.rs | 187 +------------ datafusion/expr/src/logical_plan/tree_node.rs | 10 +- .../physical-plan/src/recursive_query.rs | 252 ++---------------- datafusion/proto/src/logical_plan/mod.rs | 12 +- datafusion/sql/src/cte.rs | 63 ++--- datafusion/sqllogictest/test_files/cte.slt | 29 +- 7 files changed, 71 insertions(+), 547 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4c97c2e9b5d1b..017a123eb035b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -192,14 +192,12 @@ impl LogicalPlanBuilder { // Ensure that the recursive term has the same field types as the static term let coerced_recursive_term = coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?; - Ok(Self::from(LogicalPlan::RecursiveQuery( - RecursiveQuery::try_new( - name, - self.plan, - Arc::new(coerced_recursive_term), - is_distinct, - )?, - ))) + Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term: self.plan, + recursive_term: Arc::new(coerced_recursive_term), + is_distinct, + }))) } /// Create a values list based relation, and the schema is inferred from data, consuming @@ -2291,7 +2289,6 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { - use std::collections::HashMap; use std::vec; use super::*; @@ -2370,56 +2367,6 @@ mod tests 
{ Ok(()) } - #[test] - fn recursive_query_schema_widens_nullability_from_recursive_term() -> Result<()> { - let static_term = - LogicalPlanBuilder::empty(true).project(vec![lit(0i32).alias("n")])?; - let recursive_term = LogicalPlanBuilder::empty(true) - .project(vec![lit(ScalarValue::Int32(None)).alias("recursive_n")])? - .build()?; - - let plan = static_term - .to_recursive_query("t".to_string(), recursive_term, false)? - .build()?; - - assert_eq!(plan.schema().field(0).name(), "n"); - assert!(plan.schema().field(0).is_nullable()); - Ok(()) - } - - #[test] - fn recursive_query_schema_preserves_static_metadata() -> Result<()> { - let static_metadata = - HashMap::from([("source".to_string(), "static".to_string())]); - let recursive_metadata = - HashMap::from([("source".to_string(), "recursive".to_string())]); - let static_schema = Schema::new_with_metadata( - vec![ - Field::new("n", DataType::Int32, false) - .with_metadata(static_metadata.clone()), - ], - static_metadata.clone(), - ); - let recursive_schema = Schema::new_with_metadata( - vec![ - Field::new("recursive_n", DataType::Int32, false) - .with_metadata(recursive_metadata), - ], - HashMap::from([("source".to_string(), "recursive".to_string())]), - ); - - let static_term = table_scan(Some("static_t"), &static_schema, None)?; - let recursive_term = - table_scan(Some("recursive_t"), &recursive_schema, None)?.build()?; - let plan = static_term - .to_recursive_query("t".to_string(), recursive_term, false)? 
- .build()?; - - assert_eq!(plan.schema().field(0).metadata(), &static_metadata); - assert_eq!(plan.schema().metadata(), &static_metadata); - Ok(()) - } - #[test] fn plan_builder_union() -> Result<()> { let plan = diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 930ab47444ef7..db8b82fe87a14 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -354,6 +354,7 @@ impl LogicalPlan { LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + // we take the schema of the static term as the schema of the entire recursive query static_term.schema() } } @@ -1079,12 +1080,12 @@ impl LogicalPlan { }) => { self.assert_no_expressions(expr)?; let (static_term, recursive_term) = self.only_two_inputs(inputs)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery::try_new( - name.clone(), - Arc::new(static_term), - Arc::new(recursive_term), - *is_distinct, - )?)) + Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + name: name.clone(), + static_term: Arc::new(static_term), + recursive_term: Arc::new(recursive_term), + is_distinct: *is_distinct, + })) } LogicalPlan::Analyze(a) => { self.assert_no_expressions(expr)?; @@ -2245,7 +2246,7 @@ impl PartialOrd for EmptyRelation { /// intermediate table, then empty the intermediate table. /// /// [Postgres Docs]: https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct RecursiveQuery { /// Name of the query pub name: String, @@ -2259,112 +2260,6 @@ pub struct RecursiveQuery { pub is_distinct: bool, } -impl RecursiveQuery { - /// Create a recursive query with an output schema using static term field names - /// and nullability widened across both static and recursive terms. 
- pub fn try_new( - name: String, - static_term: Arc, - recursive_term: Arc, - is_distinct: bool, - ) -> Result { - let schema = - recursive_query_schema(static_term.schema(), recursive_term.schema())?; - let static_term = align_logical_plan_to_schema(static_term, schema)?; - Ok(Self { - name, - static_term, - recursive_term, - is_distinct, - }) - } -} - -fn align_logical_plan_to_schema( - input: Arc, - schema: DFSchemaRef, -) -> Result> { - if input.schema().as_ref() == schema.as_ref() { - return Ok(input); - } - - let expr = input - .schema() - .fields() - .iter() - .enumerate() - .map(|(i, _)| Expr::Column(Column::from(input.schema().qualified_field(i)))) - .collect(); - Ok(Arc::new(LogicalPlan::Projection( - Projection::try_new_with_schema(expr, input, schema)?, - ))) -} - -fn recursive_query_schema( - static_schema: &DFSchema, - recursive_schema: &DFSchema, -) -> Result { - if static_schema.fields().len() != recursive_schema.fields().len() { - return plan_err!( - "RecursiveQuery static and recursive terms have different number of columns: {} != {}", - static_schema.fields().len(), - recursive_schema.fields().len() - ); - } - - let fields = static_schema - .fields() - .iter() - .zip(recursive_schema.fields().iter()) - .enumerate() - .map(|(i, (static_field, recursive_field))| { - if static_field.data_type() != recursive_field.data_type() { - return plan_err!( - "RecursiveQuery column {i} has different types: static term has {} whereas recursive term has {}", - static_field.data_type(), - recursive_field.data_type() - ); - } - // Field names and qualifiers always come from the static/anchor term so - // that recursive-term column names never leak into the declared CTE schema. 
- let (qualifier, _) = static_schema.qualified_field(i); - let field = Field::new( - static_field.name(), - static_field.data_type().clone(), - // Nullability is widened (union-like) across both terms so that a - // nullable recursive expression does not force a runtime error when - // the anchor is non-nullable (e.g. `SELECT 0 AS level`). - static_field.is_nullable() || recursive_field.is_nullable(), - ) - .with_metadata(static_field.metadata().clone()); - Ok((qualifier.cloned(), Arc::new(field))) - }) - .collect::>>()?; - - Ok(Arc::new(DFSchema::new_with_metadata( - fields, - static_schema.metadata().clone(), - )?)) -} - -impl PartialOrd for RecursiveQuery { - fn partial_cmp(&self, other: &Self) -> Option { - ( - &self.name, - &self.static_term, - &self.recursive_term, - self.is_distinct, - ) - .partial_cmp(&( - &other.name, - &other.static_term, - &other.recursive_term, - other.is_distinct, - )) - .filter(|cmp| *cmp != Ordering::Equal || self == other) - } -} - /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. 
@@ -4976,72 +4871,6 @@ mod tests { ); } - fn empty_plan_with_fields(fields: Vec) -> Arc { - Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: Arc::new( - DFSchema::from_unqualified_fields(fields.into(), HashMap::new()).unwrap(), - ), - })) - } - - #[test] - fn recursive_query_try_new_aligns_static_term_to_widened_schema() -> Result<()> { - let static_term = - empty_plan_with_fields(vec![Field::new("a", DataType::Int32, false)]); - let recursive_term = - empty_plan_with_fields(vec![Field::new("b", DataType::Int32, true)]); - - let query = RecursiveQuery::try_new( - "t".to_string(), - Arc::clone(&static_term), - Arc::clone(&recursive_term), - false, - )?; - - assert_eq!(query.static_term.schema().field(0).name(), "a"); - assert!(query.static_term.schema().field(0).is_nullable()); - assert!(matches!( - query.static_term.as_ref(), - LogicalPlan::Projection(_) - )); - assert!( - Arc::ptr_eq(&query.recursive_term, &recursive_term), - "recursive term should not be wrapped in a schema-only Projection" - ); - Ok(()) - } - - #[test] - fn recursive_query_try_new_rejects_mismatched_column_count() { - let static_term = - empty_plan_with_fields(vec![Field::new("a", DataType::Int32, false)]); - let recursive_term = empty_plan_with_fields(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ]); - - let err = - RecursiveQuery::try_new("t".to_string(), static_term, recursive_term, false) - .unwrap_err(); - - assert_snapshot!(err.strip_backtrace(), @"Error during planning: RecursiveQuery static and recursive terms have different number of columns: 1 != 2"); - } - - #[test] - fn recursive_query_try_new_rejects_mismatched_types() { - let static_term = - empty_plan_with_fields(vec![Field::new("a", DataType::Int32, false)]); - let recursive_term = - empty_plan_with_fields(vec![Field::new("a", DataType::Int64, false)]); - - let err = - RecursiveQuery::try_new("t".to_string(), static_term, recursive_term, false) 
- .unwrap_err(); - - assert_snapshot!(err.strip_backtrace(), @"Error during planning: RecursiveQuery column 0 has different types: static term has Int32 whereas recursive term has Int64"); - } - #[test] fn test_partial_eq_hash_and_partial_ord() { let empty_values = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 0e54e4536439c..ef9382a57209a 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -329,18 +329,16 @@ impl TreeNode for LogicalPlan { static_term, recursive_term, is_distinct, - .. - }) => (static_term, recursive_term).map_elements(f)?.map_data( + }) => (static_term, recursive_term).map_elements(f)?.update_data( |(static_term, recursive_term)| { - RecursiveQuery::try_new( + LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term, recursive_term, is_distinct, - ) - .map(LogicalPlan::RecursiveQuery) + }) }, - )?, + ), LogicalPlan::Statement(stmt) => match stmt { Statement::Prepare(p) => p .input diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index a0c5384d7f068..c160f9a0dc763 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -81,13 +81,7 @@ pub struct RecursiveQueryExec { } impl RecursiveQueryExec { - /// Create a new [`RecursiveQueryExec`] deriving the output schema from the - /// physical children. - /// - /// This constructor is retained for backward compatibility. Planner-created - /// recursive CTEs should use [`Self::try_new_with_schema`] with the logical - /// recursive CTE schema, which can widen nullability across the static and - /// recursive terms. 
+ /// Create a new RecursiveQueryExec pub fn try_new( name: String, static_term: Arc, @@ -96,73 +90,13 @@ impl RecursiveQueryExec { ) -> Result { // Each recursive query needs its own work table let work_table = Arc::new(WorkTable::new(name.clone())); - // Preserve the legacy constructor behavior by deriving the output schema - // from the physical children. - let recursive_term = assign_work_table(recursive_term, &work_table)?; - let output_schema = recursive_query_output_schema( - &static_term.schema(), - &recursive_term.schema(), - )?; + // Use the same work table for both the WorkTableExec and the recursive term + let output_schema = + recursive_output_schema(&static_term.schema(), &recursive_term.schema()); let static_term = project_plan_to_schema(static_term, &output_schema)?; - let recursive_term = project_plan_to_schema(recursive_term, &output_schema)?; - Self::try_new_with_work_table( - name, - work_table, - static_term, - recursive_term, - &output_schema, - is_distinct, - ) - } - - /// Create a new [`RecursiveQueryExec`] with an explicit output schema. - /// - /// The supplied `output_schema` is authoritative. Both the static term and - /// recursive term are aligned to this schema at plan construction time. Use - /// this constructor when the logical recursive CTE schema is known. - /// - /// Recursive CTE schema contract: - /// - /// * field names come from the static term; - /// * data types must be compatible across static and recursive terms; - /// * nullability is widened across both terms; - /// * metadata must remain consistent with the logical schema. 
- pub fn try_new_with_schema( - name: String, - static_term: Arc, - recursive_term: Arc, - output_schema: &SchemaRef, - is_distinct: bool, - ) -> Result { - // Each recursive query needs its own work table - let work_table = Arc::new(WorkTable::new(name.clone())); let recursive_term = assign_work_table(recursive_term, &work_table)?; - // `output_schema` has already been widened for nullability at logical - // planning time, so `project_plan_to_schema` here only ever widens - // non-nullable → nullable (safe cast) or encounters equal-nullability - // fields. It will never be asked to narrow a nullable input to a - // non-null expected field for a recursive CTE child. - let static_term = project_plan_to_schema(static_term, output_schema)?; - let recursive_term = project_plan_to_schema(recursive_term, output_schema)?; - Self::try_new_with_work_table( - name, - work_table, - static_term, - recursive_term, - output_schema, - is_distinct, - ) - } - - fn try_new_with_work_table( - name: String, - work_table: Arc, - static_term: Arc, - recursive_term: Arc, - output_schema: &SchemaRef, - is_distinct: bool, - ) -> Result { - let cache = Self::compute_properties(Arc::clone(output_schema)); + let recursive_term = project_plan_to_schema(recursive_term, &output_schema)?; + let cache = Self::compute_properties(output_schema); Ok(RecursiveQueryExec { name, static_term, @@ -248,12 +182,10 @@ impl ExecutionPlan for RecursiveQueryExec { self: Arc, children: Vec>, ) -> Result> { - let output_schema = self.schema(); - RecursiveQueryExec::try_new_with_schema( + RecursiveQueryExec::try_new( self.name.clone(), Arc::clone(&children[0]), Arc::clone(&children[1]), - &output_schema, self.is_distinct, ) .map(|e| Arc::new(e) as _) @@ -438,46 +370,28 @@ impl RecursiveQueryStream { } } -fn recursive_query_output_schema( +fn recursive_output_schema( static_schema: &SchemaRef, recursive_schema: &SchemaRef, -) -> Result { - if static_schema.fields().len() != recursive_schema.fields().len() { - 
return datafusion_common::plan_err!( - "RecursiveQueryExec static and recursive terms have different number of columns: {} != {}", - static_schema.fields().len(), - recursive_schema.fields().len() - ); - } - +) -> SchemaRef { let fields = static_schema .fields() .iter() - .zip(recursive_schema.fields().iter()) - .enumerate() - .map(|(i, (static_field, recursive_field))| { - if static_field.data_type() != recursive_field.data_type() { - return datafusion_common::plan_err!( - "RecursiveQueryExec column {i} has different types: static term has {} whereas recursive term has {}", - static_field.data_type(), - recursive_field.data_type() - ); - } - Ok(Arc::new( - Field::new( - static_field.name(), - static_field.data_type().clone(), - static_field.is_nullable() || recursive_field.is_nullable(), - ) - .with_metadata(static_field.metadata().clone()), - )) + .zip(recursive_schema.fields()) + .map(|(static_field, recursive_field)| { + Field::new( + static_field.name(), + static_field.data_type().clone(), + static_field.is_nullable() || recursive_field.is_nullable(), + ) + .with_metadata(static_field.metadata().clone()) }) - .collect::>>()?; + .collect::>(); - Ok(Arc::new(Schema::new_with_metadata( + Arc::new(Schema::new_with_metadata( fields, static_schema.metadata().clone(), - ))) + )) } fn assign_work_table( @@ -610,53 +524,24 @@ mod tests { use crate::projection::ProjectionExec; use arrow::datatypes::{DataType, Field, Schema}; - use std::collections::HashMap; fn empty_exec(fields: Vec) -> Arc { - empty_exec_with_schema(Arc::new(Schema::new(fields))) - } - - fn empty_exec_with_schema(schema: SchemaRef) -> Arc { - Arc::new(EmptyExec::new(schema)) - } - - fn recursive_exec( - static_term: Arc, - recursive_term: Arc, - ) -> Result { - RecursiveQueryExec::try_new( - "numbers".to_string(), - static_term, - recursive_term, - false, - ) + Arc::new(EmptyExec::new(Arc::new(Schema::new(fields)))) } #[test] - fn recursive_query_exec_try_new_keeps_backward_compatible_default() -> 
Result<()> { + fn recursive_query_exec_projects_recursive_term_to_reconciled_schema() -> Result<()> { let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); let recursive_term = - empty_exec(vec![Field::new("value", DataType::Int32, false)]); + empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, false)]); let exec = RecursiveQueryExec::try_new( "numbers".to_string(), Arc::clone(&static_term), - recursive_term, + Arc::clone(&recursive_term), false, )?; - assert_eq!(exec.schema(), static_term.schema()); - Ok(()) - } - - #[test] - fn recursive_query_exec_projects_recursive_term_to_reconciled_schema() -> Result<()> { - let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); - let recursive_term = - empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, false)]); - - let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?; - assert_eq!(exec.schema(), static_term.schema()); let projection = exec .recursive_term() @@ -669,100 +554,21 @@ mod tests { } #[test] - fn recursive_query_exec_widens_output_nullability_from_recursive_term() -> Result<()> - { + fn recursive_query_exec_reconciles_nullability() -> Result<()> { let static_term = empty_exec(vec![Field::new("value", DataType::Int32, false)]); let recursive_term = empty_exec(vec![Field::new("value + Int32(1)", DataType::Int32, true)]); - let exec = recursive_exec(Arc::clone(&static_term), Arc::clone(&recursive_term))?; - - let expected_schema = Arc::new(Schema::new(vec![Field::new( - "value", - DataType::Int32, - true, - )])); - assert_eq!(exec.schema(), expected_schema); - assert_eq!(exec.static_term().schema(), expected_schema); - assert_eq!(exec.recursive_term().schema(), expected_schema); - assert!(exec.schema().field(0).is_nullable()); - Ok(()) - } - - #[test] - fn recursive_query_exec_with_schema_uses_declared_output_schema() -> Result<()> { - let static_term = empty_exec(vec![Field::new("anchor", DataType::Int32, 
false)]); - let recursive_term = empty_exec(vec![Field::new( - "anchor + Int32(1)", - DataType::Int32, - false, - )]); - let output_schema = Arc::new(Schema::new(vec![Field::new( - "declared", - DataType::Int32, - true, - )])); - - let exec = RecursiveQueryExec::try_new_with_schema( + let exec = RecursiveQueryExec::try_new( "numbers".to_string(), static_term, recursive_term, - &output_schema, false, )?; - assert_eq!(exec.schema(), output_schema); - assert_eq!(exec.static_term().schema(), output_schema); - assert_eq!(exec.recursive_term().schema(), output_schema); - Ok(()) - } - - #[test] - fn recursive_query_exec_preserves_static_output_metadata() -> Result<()> { - let static_field = - Field::new("value", DataType::Int32, false).with_metadata(HashMap::from([ - ("shared".to_string(), "same".to_string()), - ("static".to_string(), "only".to_string()), - ])); - let recursive_field = - Field::new("value", DataType::Int32, false).with_metadata(HashMap::from([ - ("shared".to_string(), "same".to_string()), - ("static".to_string(), "only".to_string()), - ])); - let static_schema = Arc::new(Schema::new_with_metadata( - vec![static_field], - HashMap::from([ - ("shared".to_string(), "same".to_string()), - ("static".to_string(), "only".to_string()), - ]), - )); - let recursive_schema = Arc::new(Schema::new_with_metadata( - vec![recursive_field], - HashMap::from([ - ("shared".to_string(), "same".to_string()), - ("static".to_string(), "only".to_string()), - ]), - )); - - let exec = recursive_exec( - empty_exec_with_schema(static_schema), - empty_exec_with_schema(recursive_schema), - )?; - - assert_eq!( - exec.schema().field(0).metadata(), - &HashMap::from([ - ("shared".to_string(), "same".to_string()), - ("static".to_string(), "only".to_string()), - ]) - ); - assert_eq!( - exec.schema().metadata(), - &HashMap::from([ - ("shared".to_string(), "same".to_string()), - ("static".to_string(), "only".to_string()), - ]) - ); + assert!(exec.schema().field(0).is_nullable()); + 
assert!(exec.static_term().schema().field(0).is_nullable()); + assert!(exec.recursive_term().schema().field(0).is_nullable()); Ok(()) } } diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index b245ba235a964..7ae5cbeed3e53 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1067,12 +1067,12 @@ impl AsLogicalPlan for LogicalPlanNode { ))? .try_into_logical_plan(ctx, extension_codec)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery::try_new( - recursive_query_node.name.clone(), - Arc::new(static_term), - Arc::new(recursive_term), - recursive_query_node.is_distinct, - )?)) + Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + name: recursive_query_node.name.clone(), + static_term: Arc::new(static_term), + recursive_term: Arc::new(recursive_term), + is_distinct: recursive_query_node.is_distinct, + })) } LogicalPlanType::CteWorkTableScan(cte_work_table_scan_node) => { let CteWorkTableScanNode { name, schema } = cte_work_table_scan_node; diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 74c92258889cc..18766d7056355 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -19,7 +19,6 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow::datatypes::SchemaRef; use datafusion_common::{ Result, not_impl_err, plan_err, tree_node::{TreeNode, TreeNodeRecursion}, @@ -133,10 +132,19 @@ impl SqlToRel<'_, S> { // bound to. 
// ---------- Step 2: Create a temporary relation ------------------ - // Step 2.1: Create a temporary relation logical plan that will be used + // Step 2.1: Create a table source for the temporary relation + let work_table_source = self + .context_provider + .create_cte_work_table(cte_name, Arc::clone(static_plan.schema().inner()))?; + + // Step 2.2: Create a temporary relation logical plan that will be used // as the input to the recursive term - let (work_table_source, work_table_plan) = - self.cte_work_table_plan(cte_name, Arc::clone(static_plan.schema().inner()))?; + let work_table_plan = LogicalPlanBuilder::scan( + cte_name.to_string(), + Arc::clone(&work_table_source), + None, + )? + .build()?; let name = cte_name.to_string(); @@ -151,8 +159,7 @@ impl SqlToRel<'_, S> { // this uses the named_relation we inserted above to resolve the // relation. This ensures that the recursive term uses the named relation logical plan // and thus the 'continuance' physical plan as its input and source - let recursive_plan = - self.set_expr_to_plan(*right_expr.clone(), planner_context)?; + let recursive_plan = self.set_expr_to_plan(*right_expr, planner_context)?; // Check if the recursive term references the CTE itself, // if not, it is a non-recursive CTE @@ -169,47 +176,11 @@ impl SqlToRel<'_, S> { } // ---------- Step 4: Create the final plan ------------------ - // Step 4.1: Compile the final plan. The first plan only discovers the - // fixed recursive CTE output schema. Recursive CTE nullability is - // union-like, so the recursive term can widen the work table schema. - // Replan the recursive term with that widened schema so predicates such - // as `n IS NOT NULL` are not optimized using the anchor-only schema. + // Step 4.1: Compile the final plan let distinct = !Self::is_union_all(set_quantifier)?; - let initial_recursive_query = LogicalPlanBuilder::from(static_plan.clone()) - .to_recursive_query(name.clone(), recursive_plan.clone(), distinct)? 
- .build()?; - if initial_recursive_query.schema() != static_plan.schema() { - let (_, work_table_plan) = self.cte_work_table_plan( - cte_name, - Arc::clone(initial_recursive_query.schema().inner()), - )?; - planner_context.insert_cte(cte_name.to_string(), work_table_plan); - let recursive_plan = self.set_expr_to_plan(*right_expr, planner_context)?; - planner_context.remove_cte(cte_name); - LogicalPlanBuilder::from(static_plan) - .to_recursive_query(name, recursive_plan, distinct)? - .build() - } else { - planner_context.remove_cte(cte_name); - Ok(initial_recursive_query) - } - } - - fn cte_work_table_plan( - &self, - cte_name: &str, - schema: SchemaRef, - ) -> Result<(Arc, LogicalPlan)> { - let work_table_source = self - .context_provider - .create_cte_work_table(cte_name, schema)?; - let work_table_plan = LogicalPlanBuilder::scan( - cte_name.to_string(), - Arc::clone(&work_table_source), - None, - )? - .build()?; - Ok((work_table_source, work_table_plan)) + LogicalPlanBuilder::from(static_plan) + .to_recursive_query(name, recursive_plan, distinct)? + .build() } } diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index fd19e03ef4374..d13e0d4f085e9 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -699,7 +699,7 @@ WITH RECURSIVE region_sales AS ( SELECT s.salesperson_id AS salesperson_id, SUM(s.sale_amount) AS amount, - 0 as level + SUM(0) as level FROM sales s GROUP BY @@ -1300,33 +1300,6 @@ DROP TABLE cte_schema_reread; statement ok DROP TABLE cte_schema_records; -# Recursive CTE nullability is union-like: anchor names are preserved, -# but nullable recursive output widens the CTE output schema. 
-query I -WITH RECURSIVE t AS ( - SELECT 0 AS n - UNION ALL - SELECT CAST(NULL AS INT) AS n FROM t WHERE n IS NOT NULL -) -SELECT * FROM t; ----- -0 -NULL - -# Column names in a recursive CTE must come from the anchor/static term, -# not from the recursive term, even when the recursive term uses a different alias. -query I -WITH RECURSIVE t AS ( - SELECT 0 AS anchor_col - UNION ALL - SELECT anchor_col + 1 AS recursive_col FROM t WHERE anchor_col < 2 -) -SELECT anchor_col FROM t; ----- -0 -1 -2 - statement count 0 set datafusion.execution.enable_recursive_ctes = false; From 02e1abbfe878a73c3fd221931f4eab60a24fef18 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 13 May 2026 17:52:56 +0800 Subject: [PATCH 17/24] feat: add SLT repro for recursive CTE and update nullability handling - Added SLT repro in datafusion/sqllogictest/test_files/cte.slt - Fixed recursive CTE work-table nullability: - Work table schema is now conservatively nullable - RecursiveQuery now stores output schema - Schema nullability considers static OR recursive term - Proto deserialize now rebuilds via builder - Updated affected EXPLAIN expectations --- datafusion/expr/src/logical_plan/builder.rs | 37 ++++++++++++++++++- datafusion/expr/src/logical_plan/plan.rs | 35 +++++++++++++++--- datafusion/expr/src/logical_plan/tree_node.rs | 2 + datafusion/proto/src/logical_plan/mod.rs | 16 ++++---- datafusion/sql/src/cte.rs | 20 ++++++++-- datafusion/sqllogictest/test_files/cte.slt | 23 ++++++++++-- 6 files changed, 110 insertions(+), 23 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 017a123eb035b..c07959ea7937d 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -66,6 +66,30 @@ use indexmap::IndexSet; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; +fn recursive_query_output_schema( + static_schema: &DFSchema, + recursive_schema: 
&DFSchema, +) -> Result { + let fields = static_schema + .iter() + .zip(recursive_schema.iter()) + .map(|((qualifier, static_field), (_, recursive_field))| { + let field = static_field + .as_ref() + .clone() + .with_nullable( + static_field.is_nullable() || recursive_field.is_nullable(), + ) + .into(); + (qualifier.cloned(), field) + }) + .collect::>(); + Ok(DFSchemaRef::new(DFSchema::new_with_metadata( + fields, + static_schema.metadata().clone(), + )?)) +} + /// Options for [`LogicalPlanBuilder`] #[derive(Default, Debug, Clone)] pub struct LogicalPlanBuilderOptions { @@ -192,11 +216,20 @@ impl LogicalPlanBuilder { // Ensure that the recursive term has the same field types as the static term let coerced_recursive_term = coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?; + let output_schema = recursive_query_output_schema( + self.plan.schema(), + coerced_recursive_term.schema(), + )?; + let static_term = + coerce_plan_expr_for_schema(Arc::unwrap_or_clone(self.plan), &output_schema)?; + let recursive_term = + coerce_plan_expr_for_schema(coerced_recursive_term, &output_schema)?; Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { name, - static_term: self.plan, - recursive_term: Arc::new(coerced_recursive_term), + static_term: Arc::new(static_term), + recursive_term: Arc::new(recursive_term), is_distinct, + schema: output_schema, }))) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index db8b82fe87a14..4b6328424e350 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -353,10 +353,7 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { output_schema, .. }) => output_schema, LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, - LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. 
}) => { - // we take the schema of the static term as the schema of the entire recursive query - static_term.schema() - } + LogicalPlan::RecursiveQuery(RecursiveQuery { schema, .. }) => schema, } } @@ -1076,7 +1073,10 @@ impl LogicalPlan { Ok(LogicalPlan::Distinct(distinct)) } LogicalPlan::RecursiveQuery(RecursiveQuery { - name, is_distinct, .. + name, + is_distinct, + schema, + .. }) => { self.assert_no_expressions(expr)?; let (static_term, recursive_term) = self.only_two_inputs(inputs)?; @@ -1085,6 +1085,7 @@ impl LogicalPlan { static_term: Arc::new(static_term), recursive_term: Arc::new(recursive_term), is_distinct: *is_distinct, + schema: Arc::clone(schema), })) } LogicalPlan::Analyze(a) => { @@ -2246,7 +2247,7 @@ impl PartialOrd for EmptyRelation { /// intermediate table, then empty the intermediate table. /// /// [Postgres Docs]: https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct RecursiveQuery { /// Name of the query pub name: String, @@ -2258,6 +2259,28 @@ pub struct RecursiveQuery { /// Should the output of the recursive term be deduplicated (`UNION`) or /// not (`UNION ALL`). pub is_distinct: bool, + /// The output schema of the recursive query. + pub schema: DFSchemaRef, +} + +// Manual implementation needed because of `schema` field. Comparison excludes this field. 
+impl PartialOrd for RecursiveQuery { + fn partial_cmp(&self, other: &Self) -> Option { + self.name + .partial_cmp(&other.name) + .and_then(|ordering| match ordering { + Ordering::Equal => self.static_term.partial_cmp(&other.static_term), + _ => Some(ordering), + }) + .and_then(|ordering| match ordering { + Ordering::Equal => self.recursive_term.partial_cmp(&other.recursive_term), + _ => Some(ordering), + }) + .and_then(|ordering| match ordering { + Ordering::Equal => self.is_distinct.partial_cmp(&other.is_distinct), + _ => Some(ordering), + }) + } } /// Values expression. See diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index ef9382a57209a..073f7fe7537ae 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -329,6 +329,7 @@ impl TreeNode for LogicalPlan { static_term, recursive_term, is_distinct, + schema, }) => (static_term, recursive_term).map_elements(f)?.update_data( |(static_term, recursive_term)| { LogicalPlan::RecursiveQuery(RecursiveQuery { @@ -336,6 +337,7 @@ impl TreeNode for LogicalPlan { static_term, recursive_term, is_distinct, + schema, }) }, ), diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 8228e8e6f2ff0..4b282129357ab 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -56,8 +56,7 @@ use datafusion_datasource_json::file_format::{ #[cfg(feature = "parquet")] use datafusion_datasource_parquet::file_format::{ParquetFormat, ParquetFormatFactory}; use datafusion_expr::{ - AggregateUDF, DmlStatement, FetchType, HigherOrderUDF, RecursiveQuery, SkipType, - TableSource, Unnest, + AggregateUDF, DmlStatement, FetchType, HigherOrderUDF, SkipType, TableSource, Unnest, }; use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, @@ -1085,12 +1084,13 @@ impl AsLogicalPlan for LogicalPlanNode 
{ ))? .try_into_logical_plan(ctx, extension_codec)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { - name: recursive_query_node.name.clone(), - static_term: Arc::new(static_term), - recursive_term: Arc::new(recursive_term), - is_distinct: recursive_query_node.is_distinct, - })) + LogicalPlanBuilder::from(static_term) + .to_recursive_query( + recursive_query_node.name.clone(), + recursive_term, + recursive_query_node.is_distinct, + )? + .build() } LogicalPlanType::CteWorkTableScan(cte_work_table_scan_node) => { let CteWorkTableScanNode { name, schema } = cte_work_table_scan_node; diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 18766d7056355..a08eed33f7c15 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -17,6 +17,8 @@ use std::sync::Arc; +use arrow::datatypes::{Schema, SchemaRef}; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ @@ -133,9 +135,10 @@ impl SqlToRel<'_, S> { // ---------- Step 2: Create a temporary relation ------------------ // Step 2.1: Create a table source for the temporary relation - let work_table_source = self - .context_provider - .create_cte_work_table(cte_name, Arc::clone(static_plan.schema().inner()))?; + let work_table_source = self.context_provider.create_cte_work_table( + cte_name, + nullable_schema(static_plan.schema().inner()), + )?; // Step 2.2: Create a temporary relation logical plan that will be used // as the input to the recursive term @@ -184,6 +187,17 @@ impl SqlToRel<'_, S> { } } +fn nullable_schema(schema: &SchemaRef) -> SchemaRef { + Arc::new(Schema::new_with_metadata( + schema + .fields() + .iter() + .map(|field| field.as_ref().clone().with_nullable(true)) + .collect::>(), + schema.metadata().clone(), + )) +} + fn has_work_table_reference( plan: &LogicalPlan, work_table_source: &Arc, diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index d13e0d4f085e9..26b9a44286b48 100644 --- 
a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -171,7 +171,7 @@ logical_plan 07)--------TableScan: nodes projection=[id] physical_plan 01)RecursiveQueryExec: name=nodes, is_distinct=false -02)--ProjectionExec: expr=[1 as id] +02)--ProjectionExec: expr=[CAST(1 AS Int64) as id] 03)----PlaceholderRowExec 04)--CoalescePartitionsExec 05)----ProjectionExec: expr=[id@0 + 1 as id] @@ -195,6 +195,21 @@ SELECT * FROM nodes 3 4 +# recursive self-reference must use conservative nullability even when the +# anchor term uses non-null literals. Otherwise optimizer nullability-based +# simplification can remove this semantically required IS NOT NULL guard. +query II rowsort +WITH RECURSIVE t(a, b) AS ( + SELECT 0 AS a, 0 AS b + UNION ALL + SELECT b AS a, CAST(NULL AS INT) AS b FROM t WHERE a IS NOT NULL +) +SELECT * FROM t +---- +0 0 +0 NULL +NULL NULL + # deduplicating recursive CTE with two variables works query II WITH RECURSIVE ranges AS ( @@ -1079,7 +1094,7 @@ logical_plan 07)--------TableScan: numbers projection=[n] physical_plan 01)RecursiveQueryExec: name=numbers, is_distinct=false -02)--ProjectionExec: expr=[1 as n] +02)--ProjectionExec: expr=[CAST(1 AS Int64) as n] 03)----PlaceholderRowExec 04)--CoalescePartitionsExec 05)----ProjectionExec: expr=[n@0 + 1 as n] @@ -1104,7 +1119,7 @@ logical_plan 07)--------TableScan: numbers projection=[n] physical_plan 01)RecursiveQueryExec: name=numbers, is_distinct=false -02)--ProjectionExec: expr=[1 as n] +02)--ProjectionExec: expr=[CAST(1 AS Int64) as n] 03)----PlaceholderRowExec 04)--CoalescePartitionsExec 05)----ProjectionExec: expr=[n@0 + 1 as n] @@ -1161,7 +1176,7 @@ logical_plan physical_plan 01)GlobalLimitExec: skip=0, fetch=5 02)--RecursiveQueryExec: name=r, is_distinct=false -03)----ProjectionExec: expr=[0 as k, 0 as v] +03)----ProjectionExec: expr=[CAST(0 AS Int64) as k, CAST(0 AS Int64) as v] 04)------PlaceholderRowExec 05)----SortExec: TopK(fetch=1), expr=[v@1 ASC NULLS 
LAST], preserve_partitioning=[false] 06)------WorkTableExec: name=r From 869e49a1702b62e0428c97846221c217f1b2bdd9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 13 May 2026 18:15:41 +0800 Subject: [PATCH 18/24] feat: refactor recursive query handling and nullability management - Removed public RecursiveQuery.schema - Restored original public struct shape - Kept nullability handling internal: - Recursive builder coerces terms to conservative nullable schema via existing projection schema override - Optimizer child rewrites rebuild recursive query via builder - Aggregate planner reconciles nullability only for recursive-query inputs - Updated affected SLT explain output --- datafusion/core/src/physical_planner.rs | 50 +++++++++++++++++++ datafusion/expr/src/logical_plan/builder.rs | 26 ++++++++-- datafusion/expr/src/logical_plan/plan.rs | 45 ++++------------- datafusion/expr/src/logical_plan/tree_node.rs | 2 - datafusion/sqllogictest/test_files/cte.slt | 5 +- 5 files changed, 84 insertions(+), 44 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 3b2c7a78e898e..696e47da77918 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -115,6 +115,42 @@ use itertools::{Itertools, multiunzip}; use log::debug; use tokio::sync::Mutex; +fn contains_recursive_query(plan: &LogicalPlan) -> bool { + let mut found = false; + let _ = plan.apply(|node| { + if matches!(node, LogicalPlan::RecursiveQuery(_)) { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }); + found +} + +fn reconcile_logical_schema_nullability( + logical_schema: &DFSchema, + physical_schema: &Schema, +) -> Result { + let fields = logical_schema + .iter() + .zip(physical_schema.fields()) + .map(|((qualifier, logical_field), physical_field)| { + let field = logical_field + .as_ref() + .clone() + .with_nullable( + logical_field.is_nullable() || 
physical_field.is_nullable(), + ) + .into(); + (qualifier.cloned(), field) + }) + .collect::>(); + + DFSchema::new_with_metadata(fields, logical_schema.metadata().clone())? + .with_functional_dependencies(logical_schema.functional_dependencies().clone()) +} + /// Physical query planner that converts a `LogicalPlan` to an /// `ExecutionPlan` suitable for execution. #[async_trait] @@ -987,6 +1023,20 @@ impl DefaultPhysicalPlanner { let input_exec = children.one()?; let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); + let reconciled_logical_schema; + let logical_input_schema = if schema_satisfied_by( + logical_input_schema.inner(), + &physical_input_schema, + ) || !contains_recursive_query(input) + { + logical_input_schema + } else { + reconciled_logical_schema = reconcile_logical_schema_nullability( + logical_input_schema, + &physical_input_schema, + )?; + &reconciled_logical_schema + }; let physical_input_schema_from_logical = logical_input_schema.inner(); if !options.execution.skip_physical_aggregate_schema_check diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index c07959ea7937d..3c2f9eae34274 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -90,6 +90,19 @@ fn recursive_query_output_schema( )?)) } +fn plan_with_schema(plan: LogicalPlan, schema: DFSchemaRef) -> Result { + match plan { + LogicalPlan::Projection(Projection { expr, input, .. 
}) => { + Projection::try_new_with_schema(expr, input, schema) + .map(LogicalPlan::Projection) + } + _ => Ok(LogicalPlan::Projection(Projection::new_from_schema( + Arc::new(plan), + schema, + ))), + } +} + /// Options for [`LogicalPlanBuilder`] #[derive(Default, Debug, Clone)] pub struct LogicalPlanBuilderOptions { @@ -220,16 +233,19 @@ impl LogicalPlanBuilder { self.plan.schema(), coerced_recursive_term.schema(), )?; - let static_term = - coerce_plan_expr_for_schema(Arc::unwrap_or_clone(self.plan), &output_schema)?; - let recursive_term = - coerce_plan_expr_for_schema(coerced_recursive_term, &output_schema)?; + let static_term = plan_with_schema( + coerce_plan_expr_for_schema(Arc::unwrap_or_clone(self.plan), &output_schema)?, + Arc::clone(&output_schema), + )?; + let recursive_term = plan_with_schema( + coerce_plan_expr_for_schema(coerced_recursive_term, &output_schema)?, + output_schema, + )?; Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term: Arc::new(static_term), recursive_term: Arc::new(recursive_term), is_distinct, - schema: output_schema, }))) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4b6328424e350..56a0bd5570941 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -353,7 +353,11 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { output_schema, .. }) => output_schema, LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, - LogicalPlan::RecursiveQuery(RecursiveQuery { schema, .. }) => schema, + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + // The static term is coerced to the recursive output schema when + // building a RecursiveQuery. + static_term.schema() + } } } @@ -1073,20 +1077,13 @@ impl LogicalPlan { Ok(LogicalPlan::Distinct(distinct)) } LogicalPlan::RecursiveQuery(RecursiveQuery { - name, - is_distinct, - schema, - .. + name, is_distinct, .. 
}) => { self.assert_no_expressions(expr)?; let (static_term, recursive_term) = self.only_two_inputs(inputs)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { - name: name.clone(), - static_term: Arc::new(static_term), - recursive_term: Arc::new(recursive_term), - is_distinct: *is_distinct, - schema: Arc::clone(schema), - })) + LogicalPlanBuilder::from(static_term) + .to_recursive_query(name.clone(), recursive_term, *is_distinct)? + .build() } LogicalPlan::Analyze(a) => { self.assert_no_expressions(expr)?; @@ -2247,7 +2244,7 @@ impl PartialOrd for EmptyRelation { /// intermediate table, then empty the intermediate table. /// /// [Postgres Docs]: https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct RecursiveQuery { /// Name of the query pub name: String, @@ -2259,28 +2256,6 @@ pub struct RecursiveQuery { /// Should the output of the recursive term be deduplicated (`UNION`) or /// not (`UNION ALL`). pub is_distinct: bool, - /// The output schema of the recursive query. - pub schema: DFSchemaRef, -} - -// Manual implementation needed because of `schema` field. Comparison excludes this field. -impl PartialOrd for RecursiveQuery { - fn partial_cmp(&self, other: &Self) -> Option { - self.name - .partial_cmp(&other.name) - .and_then(|ordering| match ordering { - Ordering::Equal => self.static_term.partial_cmp(&other.static_term), - _ => Some(ordering), - }) - .and_then(|ordering| match ordering { - Ordering::Equal => self.recursive_term.partial_cmp(&other.recursive_term), - _ => Some(ordering), - }) - .and_then(|ordering| match ordering { - Ordering::Equal => self.is_distinct.partial_cmp(&other.is_distinct), - _ => Some(ordering), - }) - } } /// Values expression. 
See diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 073f7fe7537ae..ef9382a57209a 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -329,7 +329,6 @@ impl TreeNode for LogicalPlan { static_term, recursive_term, is_distinct, - schema, }) => (static_term, recursive_term).map_elements(f)?.update_data( |(static_term, recursive_term)| { LogicalPlan::RecursiveQuery(RecursiveQuery { @@ -337,7 +336,6 @@ impl TreeNode for LogicalPlan { static_term, recursive_term, is_distinct, - schema, }) }, ), diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index 26b9a44286b48..db42d0eac22ff 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -1171,8 +1171,9 @@ logical_plan 03)----RecursiveQuery: is_distinct=false 04)------Projection: Int64(0) AS k, Int64(0) AS v 05)--------EmptyRelation: rows=1 -06)------Sort: r.v ASC NULLS LAST, fetch=1 -07)--------TableScan: r projection=[k, v] +06)------Projection: k, v +07)--------Sort: r.v ASC NULLS LAST, fetch=1 +08)----------TableScan: r projection=[k, v] physical_plan 01)GlobalLimitExec: skip=0, fetch=5 02)--RecursiveQueryExec: name=r, is_distinct=false From 3389785557e7df6ba94add897824a170ab13b036 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 13 May 2026 18:31:17 +0800 Subject: [PATCH 19/24] fix: update explain_tree.slt to reflect correct type casting in projection --- datafusion/sqllogictest/test_files/explain_tree.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/explain_tree.slt b/datafusion/sqllogictest/test_files/explain_tree.slt index 46d01f39a920b..1a2c4ebab437f 100644 --- a/datafusion/sqllogictest/test_files/explain_tree.slt +++ b/datafusion/sqllogictest/test_files/explain_tree.slt @@ -1582,7 +1582,7 @@ physical_plan 
04)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 05)│ ProjectionExec ││ CoalescePartitionsExec │ 06)│ -------------------- ││ │ -07)│ id: 1 ││ │ +07)│ id: CAST(1 AS Int64) ││ │ 08)└─────────────┬─────────────┘└─────────────┬─────────────┘ 09)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 10)│ PlaceholderRowExec ││ ProjectionExec │ From 69530768fac09464bca98c5df05b7af48deada88 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 13 May 2026 20:37:18 +0800 Subject: [PATCH 20/24] fix: correct SUM(0) -> 0 as level in recursive CTE query --- datafusion/sqllogictest/test_files/cte.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index db42d0eac22ff..ad10244f7e62c 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -714,7 +714,7 @@ WITH RECURSIVE region_sales AS ( SELECT s.salesperson_id AS salesperson_id, SUM(s.sale_amount) AS amount, - SUM(0) as level + 0 as level FROM sales s GROUP BY From 4b6f9faa732a902a7605ac25c9dec57f3a26f8d9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 13 May 2026 19:01:23 +0800 Subject: [PATCH 21/24] feat: update reconcile_recursive_query_input_nullability to handle only nullability widening - Only reconciles nullability widening; rejects mismatches in count, name, type, and field/schema metadata. - Removes zip truncation masking. - Renamed function contains_recursive_query_input for clarity. - Added comment to clarify aggregate recursive CTE special-case. - Updated plan_with_schema to use input schema expressions instead of target schema columns. - Introduced focused unit tests for validating allowed/rejected reconciliation cases. - Adjusted SLT explain to align with the new safer projection logic. 
--- datafusion/core/src/physical_planner.rs | 123 ++++++++++++++++---- datafusion/expr/src/logical_plan/builder.rs | 9 +- datafusion/sqllogictest/test_files/cte.slt | 5 +- 3 files changed, 107 insertions(+), 30 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 696e47da77918..fe64e814d47e1 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -115,7 +115,11 @@ use itertools::{Itertools, multiunzip}; use log::debug; use tokio::sync::Mutex; -fn contains_recursive_query(plan: &LogicalPlan) -> bool { +/// Aggregate planning normally verifies that the physical input schema satisfies +/// the logical input schema exactly. Recursive CTEs are an exception only for +/// nullability widening: logical planning may conservatively expose nullable +/// recursive output after the aggregate's logical input schema was derived. +fn contains_recursive_query_input(plan: &LogicalPlan) -> bool { let mut found = false; let _ = plan.apply(|node| { if matches!(node, LogicalPlan::RecursiveQuery(_)) { @@ -128,27 +132,45 @@ fn contains_recursive_query(plan: &LogicalPlan) -> bool { found } -fn reconcile_logical_schema_nullability( +fn reconcile_recursive_query_input_nullability( logical_schema: &DFSchema, physical_schema: &Schema, -) -> Result { - let fields = logical_schema - .iter() - .zip(physical_schema.fields()) - .map(|((qualifier, logical_field), physical_field)| { - let field = logical_field - .as_ref() - .clone() - .with_nullable( - logical_field.is_nullable() || physical_field.is_nullable(), - ) - .into(); - (qualifier.cloned(), field) - }) - .collect::>(); +) -> Result> { + if logical_schema.metadata() != physical_schema.metadata() + || logical_schema.fields().len() != physical_schema.fields().len() + { + return Ok(None); + } + + let mut widened_nullability = false; + let mut fields = Vec::with_capacity(logical_schema.fields().len()); + for ((qualifier, logical_field), 
physical_field) in + logical_schema.iter().zip(physical_schema.fields()) + { + if logical_field.name() != physical_field.name() + || logical_field.data_type() != physical_field.data_type() + || logical_field.metadata() != physical_field.metadata() + { + return Ok(None); + } + + widened_nullability |= + !logical_field.is_nullable() && physical_field.is_nullable(); + let field = logical_field + .as_ref() + .clone() + .with_nullable(logical_field.is_nullable() || physical_field.is_nullable()) + .into(); + fields.push((qualifier.cloned(), field)); + } + + if !widened_nullability { + return Ok(None); + } DFSchema::new_with_metadata(fields, logical_schema.metadata().clone())? .with_functional_dependencies(logical_schema.functional_dependencies().clone()) + .map(Some) } /// Physical query planner that converts a `LogicalPlan` to an @@ -1027,15 +1049,17 @@ impl DefaultPhysicalPlanner { let logical_input_schema = if schema_satisfied_by( logical_input_schema.inner(), &physical_input_schema, - ) || !contains_recursive_query(input) + ) || !contains_recursive_query_input(input) { logical_input_schema - } else { - reconciled_logical_schema = reconcile_logical_schema_nullability( - logical_input_schema, - &physical_input_schema, - )?; + } else if let Some(schema) = reconcile_recursive_query_input_nullability( + logical_input_schema, + &physical_input_schema, + )? 
{ + reconciled_logical_schema = schema; &reconciled_logical_schema + } else { + logical_input_schema }; let physical_input_schema_from_logical = logical_input_schema.inner(); @@ -4896,6 +4920,59 @@ digraph { assert_contains!(err.to_string(), "field data type at index"); } + #[test] + fn recursive_query_input_nullability_reconciliation_only_widens_nullability() { + let logical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]) + .to_dfschema_ref() + .unwrap(); + let physical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, true)]); + + let reconciled = reconcile_recursive_query_input_nullability( + &logical_schema, + &physical_schema, + ) + .unwrap() + .expect("nullability widening should reconcile"); + + assert!(reconciled.field(0).is_nullable()); + assert_eq!(reconciled.field(0).name(), "c1"); + assert_eq!(reconciled.field(0).data_type(), &DataType::Int32); + } + + #[test] + fn recursive_query_input_nullability_reconciliation_rejects_other_mismatches() { + let logical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]) + .to_dfschema_ref() + .unwrap(); + + let cases = [ + Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]), + Schema::new(vec![Field::new("different", DataType::Int32, true)]), + Schema::new(vec![Field::new("c1", DataType::Int64, true)]), + Schema::new(vec![ + Field::new("c1", DataType::Int32, true) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ]), + Schema::new(vec![Field::new("c1", DataType::Int32, true)]) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ]; + + for physical_schema in cases { + assert!( + reconcile_recursive_query_input_nullability( + &logical_schema, + &physical_schema, + ) + .unwrap() + .is_none(), + "should not reconcile unsupported mismatch: {physical_schema:?}" + ); + } + } + #[tokio::test] async fn test_aggregate_schema_mismatch_field_nullability() { let logical_schema = diff --git 
a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 3c2f9eae34274..3a71921a251e0 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -96,10 +96,11 @@ fn plan_with_schema(plan: LogicalPlan, schema: DFSchemaRef) -> Result Ok(LogicalPlan::Projection(Projection::new_from_schema( - Arc::new(plan), - schema, - ))), + _ => { + let exprs = plan.schema().iter().map(Expr::from).collect(); + Projection::try_new_with_schema(exprs, Arc::new(plan), schema) + .map(LogicalPlan::Projection) + } } } diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index ad10244f7e62c..a906063ace47a 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -1171,9 +1171,8 @@ logical_plan 03)----RecursiveQuery: is_distinct=false 04)------Projection: Int64(0) AS k, Int64(0) AS v 05)--------EmptyRelation: rows=1 -06)------Projection: k, v -07)--------Sort: r.v ASC NULLS LAST, fetch=1 -08)----------TableScan: r projection=[k, v] +06)------Sort: r.v ASC NULLS LAST, fetch=1 +07)--------TableScan: r projection=[k, v] physical_plan 01)GlobalLimitExec: skip=0, fetch=5 02)--RecursiveQueryExec: name=r, is_distinct=false From e76aa546e63c66ea302a6cf3e3a39775501ab5c9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 13 May 2026 19:26:42 +0800 Subject: [PATCH 22/24] feat: implement central recursive CTE schema helpers - Added `datafusion/common/src/recursive_schema.rs` with the following functions: - `make_schema_nullable` - `recursive_query_output_schema` - `reconcile_dfschema_with_schema_nullability` - Tests for nullability widening and reject mismatches. - Integrated the new schema helpers into existing components: - Updated `sql/src/cte.rs` to use the common nullable work-table schema helper. - Updated `expr/src/logical_plan/builder.rs` to use the common recursive output schema helper. 
- Updated `core/src/physical_planner.rs` to use the common physical/logical reconciliation helper. - Removed duplicated local helpers and tests from `core`, `expr`, and `sql`. Semver: - No breaking field/signature changes; added a doc-hidden helper module only. --- datafusion/common/src/lib.rs | 2 + datafusion/common/src/recursive_schema.rs | 246 ++++++++++++++++++++ datafusion/core/src/physical_planner.rs | 97 +------- datafusion/expr/src/logical_plan/builder.rs | 25 +- datafusion/sql/src/cte.rs | 16 +- 5 files changed, 253 insertions(+), 133 deletions(-) create mode 100644 datafusion/common/src/recursive_schema.rs diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 9be0941b5d575..3d494f41b3a6a 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -54,6 +54,8 @@ mod null_equality; pub mod parquet_config; pub mod parsers; pub mod pruning; +#[doc(hidden)] +pub mod recursive_schema; pub mod rounding; pub mod scalar; pub mod spans; diff --git a/datafusion/common/src/recursive_schema.rs b/datafusion/common/src/recursive_schema.rs new file mode 100644 index 0000000000000..b1b9b0d280a71 --- /dev/null +++ b/datafusion/common/src/recursive_schema.rs @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Internal helpers for recursive CTE schema reconciliation. +//! +//! Recursive CTE work-table references and children must expose schemas that +//! are conservative for nullability, while preserving every other schema +//! dimension exactly. + +use std::sync::Arc; + +use arrow::datatypes::{FieldRef, Schema, SchemaRef}; + +use crate::{DFSchema, DFSchemaRef, DataFusionError, Result}; + +/// Return an Arrow schema with all fields marked nullable, preserving field and +/// schema metadata. +#[doc(hidden)] +pub fn make_schema_nullable(schema: &Schema) -> SchemaRef { + Arc::new(Schema::new_with_metadata( + schema + .fields() + .iter() + .map(|field| field.as_ref().clone().with_nullable(true)) + .collect::>(), + schema.metadata().clone(), + )) +} + +/// Return a recursive query output schema that preserves `static_schema` except +/// for nullability widened by `recursive_schema`. +/// +/// This helper assumes recursive term expressions have already been coerced to +/// the static term's schema, and only reads field nullability from +/// `recursive_schema`. All other output schema dimensions come from +/// `static_schema`. 
+#[doc(hidden)] +pub fn recursive_query_output_schema( + static_schema: &DFSchema, + recursive_schema: &DFSchema, +) -> Result { + if static_schema.fields().len() != recursive_schema.fields().len() { + return Err(DataFusionError::Plan(format!( + "Non-recursive term and recursive term must have the same number of columns ({} != {})", + static_schema.fields().len(), + recursive_schema.fields().len() + ))); + } + + let fields = static_schema + .iter() + .zip(recursive_schema.fields()) + .map(|((qualifier, static_field), recursive_field)| { + let field = static_field + .as_ref() + .clone() + .with_nullable( + static_field.is_nullable() || recursive_field.is_nullable(), + ) + .into(); + (qualifier.cloned(), field) + }) + .collect::>(); + + DFSchema::new_with_metadata(fields, static_schema.metadata().clone())? + .with_functional_dependencies(static_schema.functional_dependencies().clone()) + .map(DFSchemaRef::new) +} + +/// Reconcile `logical_schema` with an Arrow schema, but only when the Arrow +/// schema differs by being more nullable. Returns `Ok(None)` if any other +/// schema dimension differs, so callers can report their normal schema error. 
+#[doc(hidden)] +pub fn reconcile_dfschema_with_schema_nullability( + logical_schema: &DFSchema, + physical_schema: &Schema, +) -> Result> { + if logical_schema.metadata() != physical_schema.metadata() + || logical_schema.fields().len() != physical_schema.fields().len() + { + return Ok(None); + } + + let physical_fields = physical_schema.fields().iter(); + widen_dfschema_nullability_with_fields(logical_schema, physical_fields) +} + +fn widen_dfschema_nullability_with_fields<'a>( + base_schema: &DFSchema, + widening_fields: impl Iterator, +) -> Result> { + let mut widened_nullability = false; + let mut fields = Vec::with_capacity(base_schema.fields().len()); + + for ((qualifier, base_field), widening_field) in + base_schema.iter().zip(widening_fields) + { + if base_field.name() != widening_field.name() + || base_field.data_type() != widening_field.data_type() + || base_field.metadata() != widening_field.metadata() + { + return Ok(None); + } + + widened_nullability |= !base_field.is_nullable() && widening_field.is_nullable(); + let field = base_field + .as_ref() + .clone() + .with_nullable(base_field.is_nullable() || widening_field.is_nullable()) + .into(); + fields.push((qualifier.cloned(), field)); + } + + if !widened_nullability { + return Ok(None); + } + + DFSchema::new_with_metadata(fields, base_schema.metadata().clone())? 
+ .with_functional_dependencies(base_schema.functional_dependencies().clone()) + .map(Some) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use arrow::datatypes::{DataType, Field, Schema}; + + use crate::ToDFSchema as _; + + use super::*; + + #[test] + fn make_schema_nullable_preserves_metadata() { + let schema = Schema::new_with_metadata( + vec![ + Field::new("c1", DataType::Int32, false) + .with_metadata(HashMap::from([("field".into(), "value".into())])), + ], + HashMap::from([("schema".into(), "value".into())]), + ); + + let nullable = make_schema_nullable(&schema); + + assert!(nullable.field(0).is_nullable()); + assert_eq!(nullable.field(0).metadata(), schema.field(0).metadata()); + assert_eq!(nullable.metadata(), schema.metadata()); + } + + #[test] + fn recursive_output_schema_preserves_static_dimensions_and_widens_nullability() { + let static_schema = Schema::new_with_metadata( + vec![ + Field::new("anchor_name", DataType::Int32, false) + .with_metadata(HashMap::from([("field".into(), "value".into())])), + ], + HashMap::from([("schema".into(), "value".into())]), + ) + .to_dfschema_ref() + .unwrap(); + let recursive_schema = Schema::new(vec![Field::new( + "recursive_expr_name", + DataType::Int32, + true, + )]) + .to_dfschema_ref() + .unwrap(); + + let output = + recursive_query_output_schema(&static_schema, &recursive_schema).unwrap(); + + assert_eq!(output.field(0).name(), "anchor_name"); + assert_eq!(output.field(0).data_type(), &DataType::Int32); + assert_eq!( + output.field(0).metadata(), + static_schema.field(0).metadata() + ); + assert_eq!(output.metadata(), static_schema.metadata()); + assert!(output.field(0).is_nullable()); + } + + #[test] + fn reconciliation_only_widens_nullability() { + let logical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]) + .to_dfschema_ref() + .unwrap(); + let physical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, true)]); + + let reconciled = + 
reconcile_dfschema_with_schema_nullability(&logical_schema, &physical_schema) + .unwrap() + .expect("nullability widening should reconcile"); + + assert!(reconciled.field(0).is_nullable()); + assert_eq!(reconciled.field(0).name(), "c1"); + assert_eq!(reconciled.field(0).data_type(), &DataType::Int32); + } + + #[test] + fn reconciliation_rejects_other_mismatches() { + let logical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]) + .to_dfschema_ref() + .unwrap(); + + let cases = [ + Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]), + Schema::new(vec![Field::new("different", DataType::Int32, true)]), + Schema::new(vec![Field::new("c1", DataType::Int64, true)]), + Schema::new(vec![ + Field::new("c1", DataType::Int32, true) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ]), + Schema::new(vec![Field::new("c1", DataType::Int32, true)]) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ]; + + for physical_schema in cases { + assert!( + reconcile_dfschema_with_schema_nullability( + &logical_schema, + &physical_schema, + ) + .unwrap() + .is_none(), + "should not reconcile unsupported mismatch: {physical_schema:?}" + ); + } + } +} diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index fe64e814d47e1..a123b7708bccd 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -66,6 +66,7 @@ use datafusion_common::Column; use datafusion_common::HashMap as DFHashMap; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::format::ExplainAnalyzeCategories; +use datafusion_common::recursive_schema::reconcile_dfschema_with_schema_nullability; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; @@ -132,47 +133,6 @@ fn contains_recursive_query_input(plan: &LogicalPlan) -> bool { found } -fn 
reconcile_recursive_query_input_nullability( - logical_schema: &DFSchema, - physical_schema: &Schema, -) -> Result> { - if logical_schema.metadata() != physical_schema.metadata() - || logical_schema.fields().len() != physical_schema.fields().len() - { - return Ok(None); - } - - let mut widened_nullability = false; - let mut fields = Vec::with_capacity(logical_schema.fields().len()); - for ((qualifier, logical_field), physical_field) in - logical_schema.iter().zip(physical_schema.fields()) - { - if logical_field.name() != physical_field.name() - || logical_field.data_type() != physical_field.data_type() - || logical_field.metadata() != physical_field.metadata() - { - return Ok(None); - } - - widened_nullability |= - !logical_field.is_nullable() && physical_field.is_nullable(); - let field = logical_field - .as_ref() - .clone() - .with_nullable(logical_field.is_nullable() || physical_field.is_nullable()) - .into(); - fields.push((qualifier.cloned(), field)); - } - - if !widened_nullability { - return Ok(None); - } - - DFSchema::new_with_metadata(fields, logical_schema.metadata().clone())? - .with_functional_dependencies(logical_schema.functional_dependencies().clone()) - .map(Some) -} - /// Physical query planner that converts a `LogicalPlan` to an /// `ExecutionPlan` suitable for execution. #[async_trait] @@ -1052,7 +1012,7 @@ impl DefaultPhysicalPlanner { ) || !contains_recursive_query_input(input) { logical_input_schema - } else if let Some(schema) = reconcile_recursive_query_input_nullability( + } else if let Some(schema) = reconcile_dfschema_with_schema_nullability( logical_input_schema, &physical_input_schema, )? 
{ @@ -4920,59 +4880,6 @@ digraph { assert_contains!(err.to_string(), "field data type at index"); } - #[test] - fn recursive_query_input_nullability_reconciliation_only_widens_nullability() { - let logical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]) - .to_dfschema_ref() - .unwrap(); - let physical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, true)]); - - let reconciled = reconcile_recursive_query_input_nullability( - &logical_schema, - &physical_schema, - ) - .unwrap() - .expect("nullability widening should reconcile"); - - assert!(reconciled.field(0).is_nullable()); - assert_eq!(reconciled.field(0).name(), "c1"); - assert_eq!(reconciled.field(0).data_type(), &DataType::Int32); - } - - #[test] - fn recursive_query_input_nullability_reconciliation_rejects_other_mismatches() { - let logical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]) - .to_dfschema_ref() - .unwrap(); - - let cases = [ - Schema::new(vec![ - Field::new("c1", DataType::Int32, true), - Field::new("c2", DataType::Int32, true), - ]), - Schema::new(vec![Field::new("different", DataType::Int32, true)]), - Schema::new(vec![Field::new("c1", DataType::Int64, true)]), - Schema::new(vec![ - Field::new("c1", DataType::Int32, true) - .with_metadata(HashMap::from([("key".into(), "value".into())])), - ]), - Schema::new(vec![Field::new("c1", DataType::Int32, true)]) - .with_metadata(HashMap::from([("key".into(), "value".into())])), - ]; - - for physical_schema in cases { - assert!( - reconcile_recursive_query_input_nullability( - &logical_schema, - &physical_schema, - ) - .unwrap() - .is_none(), - "should not reconcile unsupported mismatch: {physical_schema:?}" - ); - } - } - #[tokio::test] async fn test_aggregate_schema_mismatch_field_nullability() { let logical_schema = diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 3a71921a251e0..2ce2dcc9e2dfe 100644 --- 
a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -53,6 +53,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::metadata::FieldMetadata; +use datafusion_common::recursive_schema::recursive_query_output_schema; use datafusion_common::{ Column, Constraints, DFSchema, DFSchemaRef, NullEquality, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, exec_err, @@ -66,30 +67,6 @@ use indexmap::IndexSet; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; -fn recursive_query_output_schema( - static_schema: &DFSchema, - recursive_schema: &DFSchema, -) -> Result { - let fields = static_schema - .iter() - .zip(recursive_schema.iter()) - .map(|((qualifier, static_field), (_, recursive_field))| { - let field = static_field - .as_ref() - .clone() - .with_nullable( - static_field.is_nullable() || recursive_field.is_nullable(), - ) - .into(); - (qualifier.cloned(), field) - }) - .collect::>(); - Ok(DFSchemaRef::new(DFSchema::new_with_metadata( - fields, - static_schema.metadata().clone(), - )?)) -} - fn plan_with_schema(plan: LogicalPlan, schema: DFSchemaRef) -> Result { match plan { LogicalPlan::Projection(Projection { expr, input, .. 
}) => { diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index a08eed33f7c15..06e209842eb16 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -17,12 +17,11 @@ use std::sync::Arc; -use arrow::datatypes::{Schema, SchemaRef}; - use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ Result, not_impl_err, plan_err, + recursive_schema::make_schema_nullable, tree_node::{TreeNode, TreeNodeRecursion}, }; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource}; @@ -137,7 +136,7 @@ impl SqlToRel<'_, S> { // Step 2.1: Create a table source for the temporary relation let work_table_source = self.context_provider.create_cte_work_table( cte_name, - nullable_schema(static_plan.schema().inner()), + make_schema_nullable(static_plan.schema().inner()), )?; // Step 2.2: Create a temporary relation logical plan that will be used @@ -187,17 +186,6 @@ impl SqlToRel<'_, S> { } } -fn nullable_schema(schema: &SchemaRef) -> SchemaRef { - Arc::new(Schema::new_with_metadata( - schema - .fields() - .iter() - .map(|field| field.as_ref().clone().with_nullable(true)) - .collect::>(), - schema.metadata().clone(), - )) -} - fn has_work_table_reference( plan: &LogicalPlan, work_table_source: &Arc, From 18b06b0a0aaa00e58e497018878f4f59eaef2ea1 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 13 May 2026 20:46:33 +0800 Subject: [PATCH 23/24] fix: update type casting in projection for explain_analyze test - Changed projection expressions in the `parquet_recursive_projection_pushdown` test to use `CAST` for consistency and improved type safety. 
--- datafusion/core/tests/sql/explain_analyze.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index b093563d9adda..fbbbee5a31e27 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -1014,7 +1014,7 @@ async fn parquet_recursive_projection_pushdown() -> Result<()> { SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] RecursiveQueryExec: name=number_series, is_distinct=false CoalescePartitionsExec - ProjectionExec: expr=[id@0 as id, 1 as level] + ProjectionExec: expr=[CAST(id@0 AS Int64) as id, CAST(1 AS Int64) as level] FilterExec: id@0 = 1 RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)] From aa7fb40e8694f922ad35ac0af4b33fae95dbf83a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Wed, 13 May 2026 21:12:29 +0800 Subject: [PATCH 24/24] feat: update TreeNode::exists usage and optimize CTE handling - Refactored TreeNode::exists in physical planner and CTE modules - Removed redundant recursive CTE re-coercion in logical plan builder - Inlined small one-use variables in recursive schema module --- datafusion/common/src/recursive_schema.rs | 38 ++++++++++++--------- datafusion/core/src/physical_planner.rs | 13 ++----- datafusion/expr/src/logical_plan/builder.rs | 7 ++-- datafusion/sql/src/cte.rs | 21 ++++-------- 4 files changed, 33 insertions(+), 46 deletions(-) diff --git a/datafusion/common/src/recursive_schema.rs b/datafusion/common/src/recursive_schema.rs index b1b9b0d280a71..e236de484aff8 100644 --- a/datafusion/common/src/recursive_schema.rs +++ b/datafusion/common/src/recursive_schema.rs @@ -65,14 +65,16 @@ pub 
fn recursive_query_output_schema( .iter() .zip(recursive_schema.fields()) .map(|((qualifier, static_field), recursive_field)| { - let field = static_field - .as_ref() - .clone() - .with_nullable( - static_field.is_nullable() || recursive_field.is_nullable(), - ) - .into(); - (qualifier.cloned(), field) + ( + qualifier.cloned(), + static_field + .as_ref() + .clone() + .with_nullable( + static_field.is_nullable() || recursive_field.is_nullable(), + ) + .into(), + ) }) .collect::>(); @@ -95,8 +97,10 @@ pub fn reconcile_dfschema_with_schema_nullability( return Ok(None); } - let physical_fields = physical_schema.fields().iter(); - widen_dfschema_nullability_with_fields(logical_schema, physical_fields) + widen_dfschema_nullability_with_fields( + logical_schema, + physical_schema.fields().iter(), + ) } fn widen_dfschema_nullability_with_fields<'a>( @@ -117,12 +121,14 @@ fn widen_dfschema_nullability_with_fields<'a>( } widened_nullability |= !base_field.is_nullable() && widening_field.is_nullable(); - let field = base_field - .as_ref() - .clone() - .with_nullable(base_field.is_nullable() || widening_field.is_nullable()) - .into(); - fields.push((qualifier.cloned(), field)); + fields.push(( + qualifier.cloned(), + base_field + .as_ref() + .clone() + .with_nullable(base_field.is_nullable() || widening_field.is_nullable()) + .into(), + )); } if !widened_nullability { diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index a123b7708bccd..07d91a407a0a5 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -121,16 +121,9 @@ use tokio::sync::Mutex; /// nullability widening: logical planning may conservatively expose nullable /// recursive output after the aggregate's logical input schema was derived. 
fn contains_recursive_query_input(plan: &LogicalPlan) -> bool { - let mut found = false; - let _ = plan.apply(|node| { - if matches!(node, LogicalPlan::RecursiveQuery(_)) { - found = true; - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - }); - found + plan.exists(|node| Ok(matches!(node, LogicalPlan::RecursiveQuery(_)))) + // Closure always returns Ok + .unwrap() } /// Physical query planner that converts a `LogicalPlan` to an diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2ce2dcc9e2dfe..d99d4ea564cd6 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -212,13 +212,10 @@ impl LogicalPlanBuilder { coerced_recursive_term.schema(), )?; let static_term = plan_with_schema( - coerce_plan_expr_for_schema(Arc::unwrap_or_clone(self.plan), &output_schema)?, + Arc::unwrap_or_clone(self.plan), Arc::clone(&output_schema), )?; - let recursive_term = plan_with_schema( - coerce_plan_expr_for_schema(coerced_recursive_term, &output_schema)?, - output_schema, - )?; + let recursive_term = plan_with_schema(coerced_recursive_term, output_schema)?; Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term: Arc::new(static_term), diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 06e209842eb16..24e31d6ba5d03 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -20,9 +20,8 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ - Result, not_impl_err, plan_err, - recursive_schema::make_schema_nullable, - tree_node::{TreeNode, TreeNodeRecursion}, + Result, not_impl_err, plan_err, recursive_schema::make_schema_nullable, + tree_node::TreeNode, }; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource}; use sqlparser::ast::{Query, SetExpr, SetOperator, With}; @@ -190,17 +189,9 @@ fn has_work_table_reference( plan: 
&LogicalPlan, work_table_source: &Arc, ) -> bool { - let mut has_reference = false; - plan.apply(|node| { - if let LogicalPlan::TableScan(scan) = node - && Arc::ptr_eq(&scan.source, work_table_source) - { - has_reference = true; - return Ok(TreeNodeRecursion::Stop); - } - Ok(TreeNodeRecursion::Continue) + plan.exists(|node| { + Ok(matches!(node, LogicalPlan::TableScan(scan) if Arc::ptr_eq(&scan.source, work_table_source))) }) - // Closure always return Ok - .unwrap(); - has_reference + // Closure always returns Ok + .unwrap() }