1- """Entity Context Pack - Rich metadata for prod Text2SQL optimization.
1+ """Entity Context Pack — rich metadata for Text2SQL optimization.
22
33Builds ECPs from Data Fabric entity metadata + sample data at INIT time.
44No LLM generation — synonyms from field.description, samples from DF API,
1010import logging
1111import re
1212from dataclasses import dataclass , field
13- from functools import lru_cache
14- from pathlib import Path
1513from typing import Any
1614
17- from uipath .platform .entities import Entity , FieldMetadata
15+ from uipath .platform .entities import Entity
1816
1917logger = logging .getLogger (__name__ )
2018
21- _PROMPTS_DIR = Path (__file__ ).parent
22-
2319# --- Type classification sets ---
2420
2521_NUMERIC_TYPES = frozenset ({
3834
3935
4036# --- Dataclasses ---
37+ # to_dict() methods use sparse serialization (omit falsy fields) to save tokens.
4138
4239
4340@dataclass
@@ -57,11 +54,7 @@ class ColumnContext:
5754 reference_entity : str | None = None
5855
5956 def to_dict (self ) -> dict [str , Any ]:
60- """Serialize to JSON-compatible dict."""
61- d : dict [str , Any ] = {
62- "name" : self .name ,
63- "type" : self .type ,
64- }
57+ d : dict [str , Any ] = {"name" : self .name , "type" : self .type }
6558 if self .description :
6659 d ["description" ] = self .description
6760 if self .synonyms :
@@ -74,12 +67,76 @@ def to_dict(self) -> dict[str, Any]:
7467 d ["is_foreign_key" ] = True
7568 if self .reference_entity :
7669 d ["reference_entity" ] = self .reference_entity
77- d ["is_numeric" ] = self .is_numeric
78- d ["is_temporal" ] = self .is_temporal
79- d ["is_categorical" ] = self .is_categorical
70+ if self .is_numeric :
71+ d ["is_numeric" ] = True
72+ if self .is_temporal :
73+ d ["is_temporal" ] = True
74+ if self .is_categorical :
75+ d ["is_categorical" ] = True
8076 return d
8177
8278
79+ @dataclass
80+ class QueryCapabilities :
81+ """Structured SQL capabilities for LLM parsing.
82+
83+ Intentionally duplicates some sql_constraints.txt content in a
84+ machine-parseable format alongside the free-text rules.
85+ """
86+
87+ allowed_clauses : list [str ] = field (default_factory = lambda : [
88+ "SELECT" , "WHERE" , "GROUP BY" , "HAVING" , "ORDER BY" ,
89+ "LIMIT" , "OFFSET" , "DISTINCT" , "LEFT JOIN" ,
90+ ])
91+ allowed_aggregations : list [str ] = field (default_factory = lambda : [
92+ "COUNT(column_name)" , "SUM" , "AVG" , "MIN" , "MAX" ,
93+ ])
94+ allowed_expressions : list [str ] = field (default_factory = lambda : [
95+ "CASE/WHEN" , "CAST" , "COALESCE" , "NULLIF" ,
96+ "ROUND" , "ABS" , "LOWER" , "UPPER" , "TRIM" ,
97+ "arithmetic (+, -, *, /)" , "string concat (||)" ,
98+ ])
99+ allowed_predicates : list [str ] = field (default_factory = lambda : [
100+ "=" , "<>" , ">" , "<" , ">=" , "<=" ,
101+ "BETWEEN" , "IN" , "LIKE" , "IS NULL" , "IS NOT NULL" ,
102+ "AND" , "OR" ,
103+ ])
104+ disallowed : list [str ] = field (default_factory = lambda : [
105+ "SELECT *" ,
106+ "COUNT(*) — use COUNT(column_name)" ,
107+ "COUNT(DISTINCT ...) — no DISTINCT in aggregates" ,
108+ "subqueries in any clause" ,
109+ "UNION / INTERSECT / EXCEPT" ,
110+ "CTE (WITH clause)" ,
111+ "window functions (ROW_NUMBER, RANK, PARTITION BY)" ,
112+ "RIGHT JOIN / FULL OUTER JOIN / CROSS JOIN" ,
113+ "self-joins" ,
114+ "more than 4 tables in JOIN chain" ,
115+ "INSERT / UPDATE / DELETE / DDL" ,
116+ "ORDER BY columns not in SELECT" ,
117+ "HAVING without GROUP BY" ,
118+ "OFFSET without LIMIT" ,
119+ ])
120+ critical_rules : list [str ] = field (default_factory = lambda : [
121+ "ALWAYS use explicit column names — never SELECT *" ,
122+ "Use COUNT(column_name) — never COUNT(*) or COUNT(1)" ,
123+ "LIMIT is REQUIRED on every query without a WHERE clause" ,
124+ "All non-aggregated columns in SELECT must appear in GROUP BY" ,
125+ "Only LEFT JOIN is supported" ,
126+ "Maximum 4 tables in a JOIN chain" ,
127+ ])
128+
129+ def to_dict (self ) -> dict [str , Any ]:
130+ return {
131+ "allowed_clauses" : self .allowed_clauses ,
132+ "allowed_aggregations" : self .allowed_aggregations ,
133+ "allowed_expressions" : self .allowed_expressions ,
134+ "allowed_predicates" : self .allowed_predicates ,
135+ "disallowed" : self .disallowed ,
136+ "critical_rules" : self .critical_rules ,
137+ }
138+
139+
83140@dataclass
84141class EntityContextPack :
85142 """Complete context for a single entity."""
@@ -89,9 +146,9 @@ class EntityContextPack:
89146 description : str | None = None
90147 columns : list [ColumnContext ] = field (default_factory = list )
91148 row_count : int | None = None
149+ query_capabilities : QueryCapabilities = field (default_factory = QueryCapabilities )
92150
93151 def to_dict (self ) -> dict [str , Any ]:
94- """Serialize to JSON-compatible dict."""
95152 d : dict [str , Any ] = {
96153 "entity_name" : self .entity_name ,
97154 "display_name" : self .display_name ,
@@ -101,6 +158,7 @@ def to_dict(self) -> dict[str, Any]:
101158 if self .row_count is not None :
102159 d ["row_count" ] = self .row_count
103160 d ["columns" ] = [c .to_dict () for c in self .columns ]
161+ d ["query_capabilities" ] = self .query_capabilities .to_dict ()
104162 return d
105163
106164
@@ -110,11 +168,7 @@ def to_dict(self) -> dict[str, Any]:
110168def classify_field_type (sql_type_name : str ) -> tuple [bool , bool , bool ]:
111169 """Classify a SQL type into (is_numeric, is_temporal, is_categorical)."""
112170 t = sql_type_name .lower ().strip ()
113- return (
114- t in _NUMERIC_TYPES ,
115- t in _TEMPORAL_TYPES ,
116- t in _CATEGORICAL_TYPES ,
117- )
171+ return (t in _NUMERIC_TYPES , t in _TEMPORAL_TYPES , t in _CATEGORICAL_TYPES )
118172
119173
120174def extract_synonyms (field_name : str , description : str | None ) -> list [str ]:
@@ -129,18 +183,17 @@ def extract_synonyms(field_name: str, description: str | None) -> list[str]:
129183 name_lower = field_name .lower ()
130184 synonyms : set [str ] = set ()
131185
132- # Extract parenthetical content: "Total enrollment (K-12 students)"
133186 parens = re .findall (r"\(([^)]+)\)" , description )
134187 for p in parens :
135188 p_stripped = p .strip ()
136189 if p_stripped and p_stripped .lower () != name_lower :
137190 synonyms .add (p_stripped )
138191
139- # Split on delimiters
140- parts = re .split (r"[,;]|\bor\b|\baka\b|\balso known as\b" , description , flags = re .IGNORECASE )
192+ parts = re .split (
193+ r"[,;]|\bor\b|\baka\b|\balso known as\b" , description , flags = re .IGNORECASE
194+ )
141195 for part in parts :
142196 token = part .strip ().strip ("." )
143- # Only keep short phrases (likely synonyms, not full sentences)
144197 if (
145198 token
146199 and len (token .split ()) <= 4
@@ -152,7 +205,7 @@ def extract_synonyms(field_name: str, description: str | None) -> list[str]:
152205 return sorted (synonyms )
153206
154207
155- async def fetch_sample_rows (
208+ async def _fetch_sample_rows (
156209 entity_key : str , limit : int = 5
157210) -> list [dict [str , Any ]]:
158211 """Fetch sample rows from Data Fabric using list_records API."""
@@ -162,8 +215,8 @@ async def fetch_sample_rows(
162215 try :
163216 records = await sdk .entities .list_records_async (entity_key , limit = limit )
164217 return [record .model_dump (exclude = {"id" }) for record in records ]
165- except Exception as e :
166- logger .warning (f "Failed to fetch sample rows for '{ entity_key } ': { e } " )
218+ except Exception :
219+ logger .warning ("Failed to fetch sample rows for '%s'" , entity_key , exc_info = True )
167220 return []
168221
169222
@@ -186,10 +239,12 @@ def _extract_column_examples(
186239 return examples
187240
188241
242+ # --- Builders ---
243+
244+
189245async def build_entity_context_pack (entity : Entity ) -> EntityContextPack :
190246 """Build a full ECP from an Entity, including sample data from DF API."""
191- # Fetch sample rows concurrently with building column metadata
192- sample_rows = await fetch_sample_rows (entity .id )
247+ sample_rows = await _fetch_sample_rows (entity .id )
193248
194249 columns : list [ColumnContext ] = []
195250 for f in entity .fields or []:
@@ -236,9 +291,7 @@ async def build_entity_context_packs(
236291 packs : list [EntityContextPack ] = []
237292 for i , result in enumerate (results ):
238293 if isinstance (result , Exception ):
239- logger .warning (
240- f"Failed to build ECP for '{ entities [i ].name } ': { result } "
241- )
294+ logger .warning ("Failed to build ECP for '%s': %s" , entities [i ].name , result )
242295 else :
243296 packs .append (result )
244297 return packs
@@ -247,46 +300,31 @@ async def build_entity_context_packs(
247300# --- Formatting ---
248301
249302
250- @lru_cache (maxsize = 1 )
251- def _load_sql_constraints () -> str :
252- """Load SQL constraints from sql_constraints.txt."""
253- constraints_path = _PROMPTS_DIR / "sql_constraints.txt"
254- try :
255- return constraints_path .read_text (encoding = "utf-8" )
256- except FileNotFoundError :
257- logger .warning (f"SQL constraints file not found: { constraints_path } " )
258- return ""
259-
260-
261303def format_ecp_for_context (context_packs : list [EntityContextPack ]) -> str :
262304 """Format ECPs as JSON for injection into agent system prompt.
263305
264- Produces: SQL constraints + ECP JSON block.
265- The system_prompt.txt (SQL expert guidelines) is NOT included here —
266- it goes into the Studio Web system message at design time.
306+ Produces: SQL generation guidelines + SQL constraints + ECP JSON block.
267307 """
268308 if not context_packs :
269309 return ""
270310
311+ from .datafabric_tool import _load_sql_constraints , _load_system_prompt
312+
271313 lines : list [str ] = []
272314
315+ system_prompt = _load_system_prompt ()
316+ if system_prompt :
317+ lines .extend (["## SQL Query Generation Guidelines" , "" , system_prompt , "" ])
318+
273319 sql_constraints = _load_sql_constraints ()
274320 if sql_constraints :
275- lines .append ("## SQL Constraints" )
276- lines .append ("" )
277- lines .append (sql_constraints )
278- lines .append ("" )
321+ lines .extend (["## SQL Constraints" , "" , sql_constraints , "" ])
279322
280323 ecp_json = json .dumps (
281324 [pack .to_dict () for pack in context_packs ],
282325 indent = 2 ,
283326 default = str ,
284327 )
285-
286- lines .append ("## Entity Context Packs" )
287- lines .append ("" )
288- lines .append ("```json" )
289- lines .append (ecp_json )
290- lines .append ("```" )
328+ lines .extend (["## Entity Context Packs" , "" , "```json" , ecp_json , "```" ])
291329
292330 return "\n " .join (lines )
0 commit comments