Skip to content

Commit 12b31a2

Browse files
ECP
1 parent df00390 commit 12b31a2

2 files changed

Lines changed: 115 additions & 92 deletions

File tree

src/uipath_langchain/agent/react/init_node.py

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,34 +22,31 @@ async def _build_schema_context(entities: list) -> str:
2222
format_schemas_for_context,
2323
)
2424

25+
ecp_enabled = False
2526
try:
2627
from uipath.core.feature_flags import FeatureFlags
2728

28-
flag_value = FeatureFlags.is_flag_enabled("EnableEntityContextPackEnrichment")
29-
with open("/tmp/init_node_debug.log", "a") as _dbg:
30-
_dbg.write(f"[ECP] flag value: {flag_value}\n")
31-
_dbg.write(f"[ECP] all flags: {FeatureFlags._flags if hasattr(FeatureFlags, '_flags') else 'no _flags attr'}\n")
29+
ecp_enabled = FeatureFlags.is_flag_enabled(
30+
"EnableEntityContextPackEnrichment"
31+
)
32+
except Exception:
33+
logger.info("Feature flags unavailable, using basic schema context")
3234

33-
if flag_value:
35+
if ecp_enabled:
36+
try:
3437
from uipath_langchain.agent.tools.datafabric_tool import (
3538
build_entity_context_packs,
3639
format_ecp_for_context,
3740
)
3841

39-
with open("/tmp/init_node_debug.log", "a") as _dbg:
40-
_dbg.write("[ECP] Building enriched ECPs\n")
42+
logger.info("Building enriched Entity Context Packs")
4143
context_packs = await build_entity_context_packs(entities)
42-
with open("/tmp/init_node_ecp.json", "w") as _ef:
43-
import json
44-
_ef.write(json.dumps([p.to_dict() for p in context_packs], indent=2, default=str))
4544
return format_ecp_for_context(context_packs)
46-
except Exception as e:
47-
with open("/tmp/init_node_debug.log", "a") as _dbg:
48-
_dbg.write(f"[ECP] EXCEPTION: {type(e).__name__}: {e}\n")
49-
logger.warning(
50-
"ECP enrichment failed, falling back to basic schema",
51-
exc_info=True,
52-
)
45+
except Exception:
46+
logger.warning(
47+
"ECP enrichment failed, falling back to basic schema",
48+
exc_info=True,
49+
)
5350

5451
return format_schemas_for_context(entities)
5552

@@ -64,9 +61,6 @@ def create_init_node(
6461
async def graph_state_init(state: Any) -> Any:
6562
# --- Data Fabric schema fetch (INIT-time) ---
6663
schema_context: str | None = None
67-
# Debug: write to file since robot swallows stdout/stderr
68-
with open("/tmp/init_node_debug.log", "a") as _dbg:
69-
_dbg.write(f"[INIT_NODE] resources_for_init present: {resources_for_init is not None}\n")
7064
if resources_for_init:
7165
from uipath_langchain.agent.tools.datafabric_tool import (
7266
fetch_entity_schemas,
@@ -76,27 +70,13 @@ async def graph_state_init(state: Any) -> Any:
7670
entity_identifiers = get_datafabric_entity_identifiers_from_resources(
7771
resources_for_init
7872
)
79-
with open("/tmp/init_node_debug.log", "a") as _dbg:
80-
_dbg.write(f"[INIT_NODE] entity_identifiers: {entity_identifiers}\n")
8173
if entity_identifiers:
8274
logger.info(
8375
"Fetching Data Fabric schemas for %d identifier(s)",
8476
len(entity_identifiers),
8577
)
8678
entities = await fetch_entity_schemas(entity_identifiers)
87-
with open("/tmp/init_node_debug.log", "a") as _dbg:
88-
_dbg.write(f"[INIT_NODE] fetched {len(entities)} entities\n")
8979
schema_context = await _build_schema_context(entities)
90-
with open("/tmp/init_node_debug.log", "a") as _dbg:
91-
_dbg.write(f"[INIT_NODE] schema_context length: {len(schema_context) if schema_context else 0}\n")
92-
with open("/tmp/init_node_schema.txt", "w") as _sf:
93-
_sf.write(schema_context or "")
94-
if schema_context:
95-
logger.info(
96-
"Schema context length: %d chars, starts with: %.200s",
97-
len(schema_context),
98-
schema_context,
99-
)
10080

10181
# --- Resolve messages ---
10282
resolved_messages: Sequence[SystemMessage | HumanMessage] | Overwrite
@@ -110,10 +90,15 @@ async def graph_state_init(state: Any) -> Any:
11090
else:
11191
resolved_messages = list(messages)
11292

113-
# Debug: dump the full system prompt the LLM will see
114-
_msgs = resolved_messages.value if isinstance(resolved_messages, Overwrite) else resolved_messages
93+
# Log the full system prompt for debugging
94+
_msgs = (
95+
resolved_messages.value
96+
if isinstance(resolved_messages, Overwrite)
97+
else resolved_messages
98+
)
11599
for _m in _msgs:
116100
if isinstance(_m, SystemMessage):
101+
logger.debug("Full system prompt:\n%s", _m.content)
117102
with open("/tmp/init_node_full_system_prompt.txt", "w") as _fp:
118103
_fp.write(str(_m.content))
119104
break

src/uipath_langchain/agent/tools/datafabric_tool/entity_context_pack.py

Lines changed: 94 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Entity Context Pack - Rich metadata for prod Text2SQL optimization.
1+
"""Entity Context Pack — rich metadata for Text2SQL optimization.
22
33
Builds ECPs from Data Fabric entity metadata + sample data at INIT time.
44
No LLM generation — synonyms from field.description, samples from DF API,
@@ -10,16 +10,12 @@
1010
import logging
1111
import re
1212
from dataclasses import dataclass, field
13-
from functools import lru_cache
14-
from pathlib import Path
1513
from typing import Any
1614

17-
from uipath.platform.entities import Entity, FieldMetadata
15+
from uipath.platform.entities import Entity
1816

1917
logger = logging.getLogger(__name__)
2018

21-
_PROMPTS_DIR = Path(__file__).parent
22-
2319
# --- Type classification sets ---
2420

2521
_NUMERIC_TYPES = frozenset({
@@ -38,6 +34,7 @@
3834

3935

4036
# --- Dataclasses ---
37+
# to_dict() methods use sparse serialization (omit falsy fields) to save tokens.
4138

4239

4340
@dataclass
@@ -57,11 +54,7 @@ class ColumnContext:
5754
reference_entity: str | None = None
5855

5956
def to_dict(self) -> dict[str, Any]:
60-
"""Serialize to JSON-compatible dict."""
61-
d: dict[str, Any] = {
62-
"name": self.name,
63-
"type": self.type,
64-
}
57+
d: dict[str, Any] = {"name": self.name, "type": self.type}
6558
if self.description:
6659
d["description"] = self.description
6760
if self.synonyms:
@@ -74,12 +67,76 @@ def to_dict(self) -> dict[str, Any]:
7467
d["is_foreign_key"] = True
7568
if self.reference_entity:
7669
d["reference_entity"] = self.reference_entity
77-
d["is_numeric"] = self.is_numeric
78-
d["is_temporal"] = self.is_temporal
79-
d["is_categorical"] = self.is_categorical
70+
if self.is_numeric:
71+
d["is_numeric"] = True
72+
if self.is_temporal:
73+
d["is_temporal"] = True
74+
if self.is_categorical:
75+
d["is_categorical"] = True
8076
return d
8177

8278

79+
@dataclass
80+
class QueryCapabilities:
81+
"""Structured SQL capabilities for LLM parsing.
82+
83+
Intentionally duplicates some sql_constraints.txt content in a
84+
machine-parseable format alongside the free-text rules.
85+
"""
86+
87+
allowed_clauses: list[str] = field(default_factory=lambda: [
88+
"SELECT", "WHERE", "GROUP BY", "HAVING", "ORDER BY",
89+
"LIMIT", "OFFSET", "DISTINCT", "LEFT JOIN",
90+
])
91+
allowed_aggregations: list[str] = field(default_factory=lambda: [
92+
"COUNT(column_name)", "SUM", "AVG", "MIN", "MAX",
93+
])
94+
allowed_expressions: list[str] = field(default_factory=lambda: [
95+
"CASE/WHEN", "CAST", "COALESCE", "NULLIF",
96+
"ROUND", "ABS", "LOWER", "UPPER", "TRIM",
97+
"arithmetic (+, -, *, /)", "string concat (||)",
98+
])
99+
allowed_predicates: list[str] = field(default_factory=lambda: [
100+
"=", "<>", ">", "<", ">=", "<=",
101+
"BETWEEN", "IN", "LIKE", "IS NULL", "IS NOT NULL",
102+
"AND", "OR",
103+
])
104+
disallowed: list[str] = field(default_factory=lambda: [
105+
"SELECT *",
106+
"COUNT(*) — use COUNT(column_name)",
107+
"COUNT(DISTINCT ...) — no DISTINCT in aggregates",
108+
"subqueries in any clause",
109+
"UNION / INTERSECT / EXCEPT",
110+
"CTE (WITH clause)",
111+
"window functions (ROW_NUMBER, RANK, PARTITION BY)",
112+
"RIGHT JOIN / FULL OUTER JOIN / CROSS JOIN",
113+
"self-joins",
114+
"more than 4 tables in JOIN chain",
115+
"INSERT / UPDATE / DELETE / DDL",
116+
"ORDER BY columns not in SELECT",
117+
"HAVING without GROUP BY",
118+
"OFFSET without LIMIT",
119+
])
120+
critical_rules: list[str] = field(default_factory=lambda: [
121+
"ALWAYS use explicit column names — never SELECT *",
122+
"Use COUNT(column_name) — never COUNT(*) or COUNT(1)",
123+
"LIMIT is REQUIRED on every query without a WHERE clause",
124+
"All non-aggregated columns in SELECT must appear in GROUP BY",
125+
"Only LEFT JOIN is supported",
126+
"Maximum 4 tables in a JOIN chain",
127+
])
128+
129+
def to_dict(self) -> dict[str, Any]:
130+
return {
131+
"allowed_clauses": self.allowed_clauses,
132+
"allowed_aggregations": self.allowed_aggregations,
133+
"allowed_expressions": self.allowed_expressions,
134+
"allowed_predicates": self.allowed_predicates,
135+
"disallowed": self.disallowed,
136+
"critical_rules": self.critical_rules,
137+
}
138+
139+
83140
@dataclass
84141
class EntityContextPack:
85142
"""Complete context for a single entity."""
@@ -89,9 +146,9 @@ class EntityContextPack:
89146
description: str | None = None
90147
columns: list[ColumnContext] = field(default_factory=list)
91148
row_count: int | None = None
149+
query_capabilities: QueryCapabilities = field(default_factory=QueryCapabilities)
92150

93151
def to_dict(self) -> dict[str, Any]:
94-
"""Serialize to JSON-compatible dict."""
95152
d: dict[str, Any] = {
96153
"entity_name": self.entity_name,
97154
"display_name": self.display_name,
@@ -101,6 +158,7 @@ def to_dict(self) -> dict[str, Any]:
101158
if self.row_count is not None:
102159
d["row_count"] = self.row_count
103160
d["columns"] = [c.to_dict() for c in self.columns]
161+
d["query_capabilities"] = self.query_capabilities.to_dict()
104162
return d
105163

106164

@@ -110,11 +168,7 @@ def to_dict(self) -> dict[str, Any]:
110168
def classify_field_type(sql_type_name: str) -> tuple[bool, bool, bool]:
111169
"""Classify a SQL type into (is_numeric, is_temporal, is_categorical)."""
112170
t = sql_type_name.lower().strip()
113-
return (
114-
t in _NUMERIC_TYPES,
115-
t in _TEMPORAL_TYPES,
116-
t in _CATEGORICAL_TYPES,
117-
)
171+
return (t in _NUMERIC_TYPES, t in _TEMPORAL_TYPES, t in _CATEGORICAL_TYPES)
118172

119173

120174
def extract_synonyms(field_name: str, description: str | None) -> list[str]:
@@ -129,18 +183,17 @@ def extract_synonyms(field_name: str, description: str | None) -> list[str]:
129183
name_lower = field_name.lower()
130184
synonyms: set[str] = set()
131185

132-
# Extract parenthetical content: "Total enrollment (K-12 students)"
133186
parens = re.findall(r"\(([^)]+)\)", description)
134187
for p in parens:
135188
p_stripped = p.strip()
136189
if p_stripped and p_stripped.lower() != name_lower:
137190
synonyms.add(p_stripped)
138191

139-
# Split on delimiters
140-
parts = re.split(r"[,;]|\bor\b|\baka\b|\balso known as\b", description, flags=re.IGNORECASE)
192+
parts = re.split(
193+
r"[,;]|\bor\b|\baka\b|\balso known as\b", description, flags=re.IGNORECASE
194+
)
141195
for part in parts:
142196
token = part.strip().strip(".")
143-
# Only keep short phrases (likely synonyms, not full sentences)
144197
if (
145198
token
146199
and len(token.split()) <= 4
@@ -152,7 +205,7 @@ def extract_synonyms(field_name: str, description: str | None) -> list[str]:
152205
return sorted(synonyms)
153206

154207

155-
async def fetch_sample_rows(
208+
async def _fetch_sample_rows(
156209
entity_key: str, limit: int = 5
157210
) -> list[dict[str, Any]]:
158211
"""Fetch sample rows from Data Fabric using list_records API."""
@@ -162,8 +215,8 @@ async def fetch_sample_rows(
162215
try:
163216
records = await sdk.entities.list_records_async(entity_key, limit=limit)
164217
return [record.model_dump(exclude={"id"}) for record in records]
165-
except Exception as e:
166-
logger.warning(f"Failed to fetch sample rows for '{entity_key}': {e}")
218+
except Exception:
219+
logger.warning("Failed to fetch sample rows for '%s'", entity_key, exc_info=True)
167220
return []
168221

169222

@@ -186,10 +239,12 @@ def _extract_column_examples(
186239
return examples
187240

188241

242+
# --- Builders ---
243+
244+
189245
async def build_entity_context_pack(entity: Entity) -> EntityContextPack:
190246
"""Build a full ECP from an Entity, including sample data from DF API."""
191-
# Fetch sample rows concurrently with building column metadata
192-
sample_rows = await fetch_sample_rows(entity.id)
247+
sample_rows = await _fetch_sample_rows(entity.id)
193248

194249
columns: list[ColumnContext] = []
195250
for f in entity.fields or []:
@@ -236,9 +291,7 @@ async def build_entity_context_packs(
236291
packs: list[EntityContextPack] = []
237292
for i, result in enumerate(results):
238293
if isinstance(result, Exception):
239-
logger.warning(
240-
f"Failed to build ECP for '{entities[i].name}': {result}"
241-
)
294+
logger.warning("Failed to build ECP for '%s': %s", entities[i].name, result)
242295
else:
243296
packs.append(result)
244297
return packs
@@ -247,46 +300,31 @@ async def build_entity_context_packs(
247300
# --- Formatting ---
248301

249302

250-
@lru_cache(maxsize=1)
251-
def _load_sql_constraints() -> str:
252-
"""Load SQL constraints from sql_constraints.txt."""
253-
constraints_path = _PROMPTS_DIR / "sql_constraints.txt"
254-
try:
255-
return constraints_path.read_text(encoding="utf-8")
256-
except FileNotFoundError:
257-
logger.warning(f"SQL constraints file not found: {constraints_path}")
258-
return ""
259-
260-
261303
def format_ecp_for_context(context_packs: list[EntityContextPack]) -> str:
262304
"""Format ECPs as JSON for injection into agent system prompt.
263305
264-
Produces: SQL constraints + ECP JSON block.
265-
The system_prompt.txt (SQL expert guidelines) is NOT included here —
266-
it goes into the Studio Web system message at design time.
306+
Produces: SQL generation guidelines + SQL constraints + ECP JSON block.
267307
"""
268308
if not context_packs:
269309
return ""
270310

311+
from .datafabric_tool import _load_sql_constraints, _load_system_prompt
312+
271313
lines: list[str] = []
272314

315+
system_prompt = _load_system_prompt()
316+
if system_prompt:
317+
lines.extend(["## SQL Query Generation Guidelines", "", system_prompt, ""])
318+
273319
sql_constraints = _load_sql_constraints()
274320
if sql_constraints:
275-
lines.append("## SQL Constraints")
276-
lines.append("")
277-
lines.append(sql_constraints)
278-
lines.append("")
321+
lines.extend(["## SQL Constraints", "", sql_constraints, ""])
279322

280323
ecp_json = json.dumps(
281324
[pack.to_dict() for pack in context_packs],
282325
indent=2,
283326
default=str,
284327
)
285-
286-
lines.append("## Entity Context Packs")
287-
lines.append("")
288-
lines.append("```json")
289-
lines.append(ecp_json)
290-
lines.append("```")
328+
lines.extend(["## Entity Context Packs", "", "```json", ecp_json, "```"])
291329

292330
return "\n".join(lines)

0 commit comments

Comments
 (0)