diff --git a/.env.example b/.env.example index f8680a7..7cba633 100644 --- a/.env.example +++ b/.env.example @@ -1,7 +1,7 @@ # Azure OpenAI Configuration AZURE_OPENAI_ENDPOINT=https://your-openai-resource.openai.azure.com/ AZURE_OPENAI_DEPLOYMENT=gpt-5-mini -AZURE_OPENAI_EMBEDDING_DEPLOYMENT=text-embedding-ada-002 +AZURE_OPENAI_EMBEDDING_DEPLOYMENT=text-embedding-3-small AZURE_TENANT_ID=your-tenant-id # MCP Server Configuration diff --git a/README.md b/README.md index a505c73..ad8a31b 100644 --- a/README.md +++ b/README.md @@ -9,351 +9,302 @@ products: - langchain urlFragment: langchain-agent-mcp name: LangChain Python Agent with Model Context Protocol (MCP) -description: A production-ready LangChain agent in Python using Azure OpenAI Responses API with MCP server integration, deployed on Azure Container Apps +description: A LangChain agent in Python that uses the Azure OpenAI Responses API and Model Context Protocol, deployed to Azure Container Apps with one command. --- # LangChain Agent with Model Context Protocol (MCP) +A two-service Python sample that shows how to wire a [LangChain](https://python.langchain.com/) agent to a [Model Context Protocol](https://modelcontextprotocol.io/) server, run them on [Azure Container Apps](https://learn.microsoft.com/azure/container-apps/), and back them with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/) and Postgres + [pgvector](https://github.com/pgvector/pgvector). Use it as a reference for building your own agent + tool-server architecture on Azure. + ![LangChain MCP Agent](images/app-image.png) [![Open in GitHub Codespaces](https://github.com/codespaces/badge.svg)](https://codespaces.new/Azure-Samples/langchain-agent-python) -This sample demonstrates a **production-ready LangChain agent** that uses the **OpenAI Responses API** with **Model Context Protocol (MCP)** for tool integration. The agent uses **Azure OpenAI GPT-5-mini** with **Entra ID authentication**, **PostgreSQL with pgvector** for semantic search, and is deployed as microservices on **Azure Container Apps**. - -This is a simplified, Python version inspired by the [Microsoft AI Tour WRK540 workshop](https://github.com/microsoft/aitour26-WRK540-unlock-your-agents-potential-with-model-context-protocol), but uses the same product data and instructions. +## What you'll learn -## Features - -**LangChain with Responses API** - Uses OpenAI's latest Responses API for native MCP tool support -**Azure OpenAI GPT-5-mini** - Latest reasoning model deployed via Azure -**PostgreSQL with pgvector** - Semantic search over product catalog using vector embeddings -**Entra ID Authentication** - Keyless authentication using Managed Identity (no API keys) -**MCP Server** - FastMCP server with database and semantic search tools -**Microservices Architecture** - Agent and MCP server deployed as independent container apps -**Infrastructure as Code** - Complete Bicep templates with Azure best practices -**One-command Deployment** - Deploy everything with `azd up` +- How to call the **Azure OpenAI Responses API** from LangChain, including hosted server-side tools (`code_interpreter`, `web_search_preview`). +- How to expose database operations as **MCP tools** with FastMCP and connect them to the agent over streamable HTTP. +- How to use **Entra ID (Managed Identity)** for keyless auth to Azure OpenAI and Postgres. +- How to provision the whole stack — Container Apps, Azure OpenAI, Postgres Flexible Server, monitoring — with **`azd up`**. 
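+
+The first of those patterns is small enough to show up front. Below is a minimal sketch distilled from `agent/app.py`; the endpoint and deployment name are placeholders, and the full version appears under *How it works*:
+
+```python
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+from langchain_openai import ChatOpenAI
+
+# Entra ID: a callable that hands ChatOpenAI a fresh bearer token per request
+token_provider = get_bearer_token_provider(
+    DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default"
+)
+
+llm = ChatOpenAI(
+    model="gpt-5-mini",  # your chat deployment name
+    base_url="https://<your-resource>.openai.azure.com/openai/v1",  # placeholder
+    api_key=token_provider,  # keyless auth, exactly as agent/app.py does it
+    use_responses_api=True,
+)
+print(llm.invoke("Say hello in five words.").content)
+```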
## Architecture

-```markdown
+Two services, deployed independently as Container Apps:
+
+```text
┌─────────────────────────────────────────────────────────────┐
│                         Azure Cloud                         │
-│                                                             │
│  ┌──────────────────────────────────────────────────────┐   │
│  │           Azure Container Apps Environment           │   │
│  │                                                      │   │
-│  │  ┌─────────────────┐      ┌──────────────────┐      │   │
-│  │  │ Agent Container │──────│   MCP Server     │      │   │
-│  │  │ - LangChain     │ HTTP │ - PostgreSQL     │      │   │
-│  │  │ - Responses API │◄─────│ - Semantic       │      │   │
-│  │  │                 │      │   Search         │      │   │
-│  │  └─────────┬───────┘      └────────┬─────────┘      │   │
-│  │            │                       │                │   │
+│  │  ┌─────────────────┐       ┌──────────────────┐     │   │
+│  │  │      agent      │──HTTP─│    mcp-server    │     │   │
+│  │  │  LangChain +    │       │  FastMCP +       │     │   │
+│  │  │  Responses API  │◄──────│  Postgres tools  │     │   │
+│  │  └────────┬────────┘       └─────────┬────────┘     │   │
│  └───────────┼─────────────────────────┼───────────────┘   │
-│              │ Entra ID               │                     │
-│              ▼                        ▼                     │
-│  ┌─────────────────────────┐  ┌──────────────────────┐     │
-│  │   Azure OpenAI          │  │   PostgreSQL         │     │
-│  │   - GPT-5-mini          │  │   - pgvector         │     │
-│  │   - text-embedding-     │  │   - Zava database    │     │
-│  │     3-small             │  │                      │     │
-│  └─────────────────────────┘  └──────────────────────┘     │
+│             │ Entra ID                 │                    │
+│             ▼                          ▼                    │
+│  ┌────────────────────────┐   ┌──────────────────────┐     │
+│  │   Azure OpenAI         │   │  Postgres Flexible   │     │
+│  │   gpt-5-mini           │   │  Server + pgvector   │     │
+│  │   text-embedding-      │   │  Zava retail schema  │     │
+│  │   3-small              │   │  (~424 products)     │     │
+│  └────────────────────────┘   └──────────────────────┘     │
└─────────────────────────────────────────────────────────────┘
-
-Local Development:
-┌──────────────┐      ┌──────────────┐      ┌──────────────┐
-│   agent.py   │─HTTP─│  MCP Server  │      │ Azure OpenAI │
-│  (localhost) │◄─────│  (localhost: │─────▶│   (cloud)    │
-│              │      │     8000)    │ Entra│              │
-└──────────────┘      └──────────────┘  ID  └──────────────┘
```

+The agent is the only public-facing service. The MCP server is reachable only from inside the Container Apps environment.

-## Prerequisites
-
-### Cloud Deployment
+## Prerequisites

-- **Azure Subscription** - [Create one for free](https://azure.microsoft.com/free/)
-- **Azure Developer CLI (azd)** - [Install azd](https://learn.microsoft.com/azure/developer/azure-developer-cli/install-azd)
-- **Azure CLI** - [Install Azure CLI](https://learn.microsoft.com/cli/azure/install-azure-cli)
-
-### Local Development
-
-- **Python 3.11+** - [Download Python](https://www.python.org/downloads/)
-- **Docker Desktop** - [Install Docker](https://www.docker.com/products/docker-desktop)
+- An Azure subscription. [Create one for free](https://azure.microsoft.com/free/).
+- [Azure Developer CLI (`azd`)](https://learn.microsoft.com/azure/developer/azure-developer-cli/install-azd).
+- [Azure CLI](https://learn.microsoft.com/cli/azure/install-azure-cli).
+- Python 3.11+ (only required for local development).
+- Docker (only required for the full local stack).

-### Quick Start (Recommended)
-
-- **GitHub Codespaces** - Click the badge above to start in a pre-configured environment with all tools installed!
+The fastest path is to open the repo in **GitHub Codespaces** — every tool above is preinstalled.

-## Quick Start
+## Quick start

-### Deploy to Azure (5 minutes)
+Deploy the whole stack to Azure with one command:

```bash
-# 1. Login to Azure
az login
azd auth login
-
-# 2. Deploy everything
azd up
```

-That's it! The `azd up` command will:
+`azd up` provisions Azure OpenAI (with `gpt-5-mini` and `text-embedding-3-small`), a Postgres Flexible Server with pgvector, a Container Apps environment, and the two container images. After the build finishes, a postprovision hook seeds the database with the Zava DIY catalogue (~424 products with pre-computed embeddings).

-- Provision Azure OpenAI with GPT-5-mini
-- Create Container Apps environment
-- Build and deploy both the agent and MCP server containers
-- Configure networking and managed identity
-- Set up monitoring with Application Insights
+When it finishes you'll see something like:

-After deployment completes, you'll see output like:
+```text
+🚀 Your LangChain Agent is Ready!
+🌐 Web chat:   https://ca-agent-abc123.region.azurecontainerapps.io/
+   Health:     https://ca-agent-abc123.region.azurecontainerapps.io/api/health
+   MCP Server: https://ca-mcp-abc123.region.azurecontainerapps.io/mcp
```
-SUCCESS: Your application was provisioned and deployed to Azure!
-Endpoints:
-  - MCP Server: https://ca-mcp-abc123.region.azurecontainerapps.io
-  - Agent: https://ca-agent-abc123.region.azurecontainerapps.io
-```

+Open the web chat URL and try:

-### Local Development
+- *What tables are in the database?*
+- *Find me 3 hammers.*
+- *Show sales by store as a pie chart.*

-**Option 1: Use Azure Database (Recommended)**
+To remove every resource later, run `azd down`.

-```bash
-# 1. Deploy to Azure first
-azd up
-
-# 2. Get configuration and set MCP server URL
-azd env get-values > .env.local
-echo "MCP_SERVER_URL=http://localhost:8000" >> .env.local
-
-# 3. Start MCP server (Terminal 1)
-cd mcp
-source ../.env.local
-python app.py
-
-# 4. Start agent server (Terminal 2)
-cd agent
-source ../.env.local
-PORT=8001 python app.py
+## Repository layout

-# 5. Open browser to http://localhost:8001
+```text
+.
+├── agent/                 # Public-facing chat service (LangChain + Responses API)
+│   ├── app.py             # Starlette app, lifespan, streaming /api/chat
+│   ├── streaming.py       # Pure parser for normalising LangChain stream chunks
+│   ├── instructions.txt   # System prompt for the agent
+│   └── static/            # Single-page chat UI
+├── mcp/                   # Internal tool server (FastMCP)
+│   └── app.py             # 4 MCP tools over Postgres + pgvector
+├── data/                  # Pre-generated catalogue + seed scripts
+├── infra/                 # Bicep templates and parameters used by `azd up`
+└── azure.yaml             # azd service definitions and hooks
```

-**Option 2: Full Local Stack**
+## How it works

-```bash
-# 1. Start PostgreSQL with pgvector
-docker-compose up -d
-
-# 2. Configure environment
-cp .env.example .env.local
-# Edit .env.local with your Azure OpenAI credentials
-
-# 3. Initialize database
-cd data
-source ../.env.local
-python generate_database.py
+### 1. The agent — LangChain on the Responses API

-# 4. Regenerate embeddings (required if your Azure OpenAI uses a different embedding model)
-python regenerate_embeddings.py
+`agent/app.py` builds the agent at startup inside a Starlette `lifespan` hook so the MCP connection and OpenAI credentials are reused across requests:

-# 5. Start MCP server (Terminal 1)
-cd mcp
-source ../.env.local
-python app.py
-
-# 6. Start agent server (Terminal 2)
-cd agent
-source ../.env.local
-PORT=8001 python app.py
+```python
+mcp_tools = await MultiServerMCPClient(
+    {"zava-sales": {"url": MCP_SERVER_URL, "transport": "streamable_http"}}
+).get_tools()
+
+server_tools = [
+    {"type": "web_search_preview"},
+    {"type": "code_interpreter", "container": {"type": "auto"}},
+]
+
+model = ChatOpenAI(
+    model=OPENAI_DEPLOYMENT,
+    base_url=OPENAI_ENDPOINT,
+    api_key=token_provider,  # Entra ID — no API key
+    use_responses_api=True,
+    include=["code_interpreter_call.outputs"],
+)

-# 7. Open browser to http://localhost:8001
+agent = create_agent(model=model, tools=server_tools + mcp_tools, system_prompt=SYSTEM_PROMPT)
```

-**VS Code Tasks:**
+A few things worth noting:

-The project includes pre-configured VS Code tasks. Press `Cmd+Shift+P` (Mac) or `Ctrl+Shift+P` (Windows/Linux) and select "Tasks: Run Task" to see available tasks:
-- Start MCP Server
-- Start Agent
-- Start PostgreSQL (Docker)
-- Initialize Database
+- `use_responses_api=True` opts into OpenAI's Responses API, which lets the model call **hosted** tools like `code_interpreter` and `web_search_preview` directly — no extra Python runtime needed for chart generation.
+- `include=["code_interpreter_call.outputs"]` asks the API to stream the tool outputs (including any generated images) back inline.
+- `api_key=token_provider` is a callable that returns a fresh Entra ID bearer token. There are no API keys anywhere in the stack.

-**Ports:**
+### 2. The MCP server — FastMCP over streamable HTTP

-- MCP Server: `8000`
-- Agent/Chat UI: `8001` (set via `PORT` environment variable)
+`mcp/app.py` uses [FastMCP](https://github.com/jlowin/fastmcp) to expose four read-only tools to the agent:

-## How It Works
+| Tool | Purpose |
+|------|---------|
+| `get_current_utc_date` | Returns the current UTC time so the agent can interpret words like *"last quarter"* against a known anchor. |
+| `get_table_schemas` | Returns the column definitions for every table in the `retail` schema. The agent reads this once before composing SQL. |
+| `execute_sales_query` | Runs a parameterised read-only SQL query against Postgres. |
+| `semantic_search_products` | Embeds the user's natural-language description with the configured embedding deployment (`text-embedding-3-small` by default) and runs a pgvector similarity search. |

-### 1. **Agent (LangChain + Responses API)**
-
-The agent uses LangChain's `ChatOpenAI` with the new **Responses API** for native MCP tool support:
+Tools are decorated with FastMCP annotations that tell the model what to expect:

```python
-from langchain_openai import ChatOpenAI
-
-llm = ChatOpenAI(
-    model="gpt-5-mini",
-    base_url=f"{endpoint}/openai/v1/",
-    api_key=token_provider,
-    model_kwargs={"use_responses_api": True}
-)
-
-# Bind MCP tools from server
-mcp_tools = get_mcp_tools(mcp_server_url)
-llm = llm.bind_tools(mcp_tools)
+mcp = FastMCP("Zava Sales Analysis Tools", lifespan=lifespan)
+
+@mcp.tool(annotations={"title": "Semantic Product Search", "readOnlyHint": True})
+async def semantic_search_products(
+    query_description: Annotated[str, Field(description="Natural-language description of the product")],
+    threshold: float = 0.5,
+    max_rows: int = 10,
+) -> list[dict]:
+    ...
```

-### 2. **MCP Server (FastMCP)**
+The agent talks to this server over the `streamable_http` MCP transport — no shared library, just HTTP. That's what makes it easy to swap the MCP server out for one written in any other language.

-The MCP server exposes tools via FastMCP:
+### 3. Authentication — Entra ID end to end

-```python
-from fastmcp import FastMCP
-
-mcp = FastMCP("Data Analysis Tools")
+Every cross-service hop uses Managed Identity:

-@mcp.tool()
-def execute_query(query: str) -> dict:
-    """Execute SQL query on database."""
-    # ... implementation
-```
+- The agent's container has a user-assigned identity granted **Cognitive Services User** on the Azure OpenAI account.
+- The MCP server's container uses the same identity to authenticate to **Azure Database for PostgreSQL** and to **Azure OpenAI** (for embedding queries).
+- There are no client secrets, connection strings with passwords, or API keys committed to the repo or stored in Container Apps env vars. -### 3. **Environment-based Configuration** +### 4. Infrastructure — Bicep + `azd` -Both services use environment-aware configuration: +`infra/main.bicep` provisions everything in a single deployment: -- **Local**: Uses `.env.local` file, connects to `localhost:8000` for MCP -- **Production**: Uses environment variables from Container Apps, connects via HTTPS to cloud MCP server +- Azure OpenAI account with two model deployments (chat + embeddings). +- Postgres Flexible Server with `pgvector` enabled and Entra ID auth on. +- Container Apps environment plus two Container Apps (`agent` and `mcp-server`). +- Log Analytics workspace and Application Insights for observability. -### 4. **Secure Authentication** +`azure.yaml` declares the two services, points them at their Dockerfiles, and registers a `postprovision` hook that creates the `retail` schema, loads the seed JSON files, and regenerates embeddings against whatever embedding model was actually deployed. -- **Azure OpenAI**: Uses Managed Identity (Entra ID) - no API keys -- **MCP Server**: Internal Container Apps networking -- **Monitoring**: Application Insights for observability +## Local development -## Available MCP Tools +You have two options. Both assume you've run `azd up` at least once so Azure OpenAI exists. -The MCP server provides these tools to the agent: +### Option 1 — Cloud Postgres, local services (recommended) -1. **`get_current_utc_date()`** - Returns current UTC timestamp for time-sensitive queries -2. **`get_table_schemas()`** - Returns PostgreSQL database schema information -3. **`execute_sales_query(query: str)`** - Executes SQL queries on PostgreSQL database -4. **`semantic_search_products(query_description: str)`** - Semantic product search using pgvector - -## Database +```bash +# Pull the deployed environment values +azd env get-values > .env.local +echo "MCP_SERVER_URL=http://localhost:8000" >> .env.local -The sample uses **Azure PostgreSQL Flexible Server** with **pgvector** for semantic product search. +# Terminal 1 — MCP server +cd mcp && source ../.env.local && python app.py -**Key Features:** +# Terminal 2 — agent +cd agent && source ../.env.local && PORT=8001 python app.py -- 10-table retail schema (products, orders, customers, inventory, etc.) -- Vector embeddings for semantic search using Azure OpenAI -- Pre-populated Zava DIY product catalog with ~424 products -- Natural language queries like "waterproof outdoor electrical boxes" +# Open http://localhost:8001 +``` -**Data Files Included:** +This runs both Python services on your machine but uses the cloud Postgres and Azure OpenAI deployments. 
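+
+Before opening the UI, it's worth confirming that the agent actually reached the MCP server. The health endpoint in `agent/app.py` reports readiness and the number of MCP tools it loaded:
+
+```bash
+# Expect "ready": true and a non-zero "mcp_tool_count"
+curl -s http://localhost:8001/api/health | jq '{ready, mcp_tool_count}'
+```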
-This repository includes pre-generated data files in the `data/` folder, so you don't need to download anything:
-- `products_pregenerated.json` - 424 products with pre-computed embeddings
-- `customers_pregenerated.json` - 500 sample customers
-- `orders_pregenerated.json` - 2000 sample orders
+### Option 2 — Full local stack

-**Setup:**
+```bash
+docker compose up -d              # local Postgres + pgvector
+cp .env.example .env.local        # add your Azure OpenAI endpoint
+cd data && source ../.env.local && \
+  python generate_database.py && \
+  python regenerate_embeddings.py # match embeddings to your deployment
+# Then start mcp/ and agent/ as in Option 1
+```

-- Production: Automatically provisioned during `azd up`
-- Local: Run `docker-compose up -d` then `python data/generate_database.py`
+VS Code tasks (`Cmd/Ctrl+Shift+P` → *Tasks: Run Task*) are pre-configured for **Start MCP Server**, **Start Agent**, **Start PostgreSQL (Docker)**, and **Initialize Database**.

-## Customization
+## Customise it

-### Add New MCP Tools
+### Add a new MCP tool

-Edit `mcp/mcp_server.py`:
+Add a function to `mcp/app.py` and decorate it. The agent will pick it up on the next start:

```python
-@mcp.tool()
-def my_custom_tool(param: str) -> dict:
-    """Description of what this tool does."""
-    # Your implementation
-    return {"result": "data"}
+@mcp.tool(annotations={"title": "Top Categories", "readOnlyHint": True})
+async def top_categories(limit: int = 5) -> list[dict]:
+    """Return the top-selling product categories."""
+    rows = await db_provider.fetch(
+        "SELECT category, SUM(line_total) AS sales "
+        "FROM retail.order_items GROUP BY category "
+        "ORDER BY sales DESC LIMIT $1", limit,
+    )
+    return [dict(r) for r in rows]
```

-The tool is automatically exposed via the `/tools` endpoint in OpenAI function format.
-
-### Change the Model
+### Change the model

Edit `infra/main.parameters.json`:

```json
-{
-  "openAiModelName": {
-    "value": "gpt-5-mini"  // Change to gpt-4o, etc.
-  }
-}
+{ "openAiModelName": { "value": "gpt-5-mini" } }
```

-### Modify System Instructions
+Use a model that supports the Responses API. Note that not every model supports every hosted tool — check the [Azure OpenAI model matrix](https://learn.microsoft.com/azure/ai-services/openai/concepts/models).

-Edit `agent/instructions.txt` to change the agent's behavior and personality.
+### Adjust agent behaviour

-## Monitoring
+`agent/instructions.txt` is the system prompt. It controls tone, when to call which tool, default assumptions about timeframes, chart preferences, and so on. Edit it and redeploy with `azd deploy agent`.

-View logs and metrics in Azure Portal:
+## Monitoring

```bash
-# Open Application Insights
-azd monitor
-
-# View container logs
-az containerapp logs show --name --resource-group --follow
+azd monitor                                                           # opens Application Insights
+az containerapp logs show -n <app-name> -g <resource-group> --follow  # tail logs
```

-## Troubleshooting
-
-**"Deployment quota exceeded"**
-→ Try a different region: `azd env set AZURE_LOCATION eastus2`
-
-**"Authentication failed"**
-→ Ensure you're logged in: `az login && azd auth login`
-
-**"GPT-5-mini not available"**
-→ Model may not be available in your region - try eastus or westus
+Application Insights captures every request to `/api/chat`, every MCP tool call, and every Azure OpenAI request, with end-to-end traces.
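+
+For ad-hoc questions, the same telemetry can be queried from the CLI. A sketch, assuming the `application-insights` Azure CLI extension is available and the placeholders are replaced with your own resource names:
+
+```bash
+# Hourly /api/chat volume and failure count over the last 24 hours
+az monitor app-insights query \
+  --app <app-insights-name> --resource-group <resource-group> --offset 24h \
+  --analytics-query "requests
+    | where url endswith '/api/chat'
+    | summarize total = count(), failed = countif(success == false) by bin(timestamp, 1h)
+    | order by timestamp desc"
+```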
-**"Container apps failing to start"** -→ Check logs: `azd monitor` - -**"MCP tools not loading"** -→ Ensure MCP_SERVER_URL is accessible from the agent +## Troubleshooting -## Clean Up +| Symptom | Fix | +|---------|-----| +| `Deployment quota exceeded` | Set a different region: `azd env set AZURE_LOCATION eastus2` then re-run `azd up`. | +| `Authentication failed` | Re-login with `az login && azd auth login`. | +| `gpt-5-mini` not available in region | Try `eastus2`, `westus`, or `swedencentral`. Verify in the [Azure OpenAI model matrix](https://learn.microsoft.com/azure/ai-services/openai/concepts/models). | +| Container Apps not starting | `azd monitor` and inspect the *Revision* logs in the portal, or `az containerapp logs show`. | +| Agent loads no tools (`mcp_tool_count: 0`) | Check that `MCP_SERVER_URL` points at `https:///mcp` and that the MCP container is `Running`. | +| Semantic search returns nothing | The seed embeddings must be generated by the same model deployed in your Azure OpenAI account. Re-run `azd hooks run postprovision`. | -Remove all Azure resources: +## Clean up ```bash azd down ``` +This deletes the resource group and every resource provisioned by `azd up`. + ## Resources -- [Azure OpenAI Documentation](https://learn.microsoft.com/azure/ai-services/openai/) -- [LangChain Documentation](https://python.langchain.com/) -- [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) -- [FastMCP Framework](https://github.com/jlowin/fastmcp) -- [Azure Developer CLI (azd)](https://learn.microsoft.com/azure/developer/azure-developer-cli/) -- [Original Workshop (WRK540)](https://github.com/microsoft/aitour26-WRK540-unlock-your-agents-potential-with-model-context-protocol) +- [Azure OpenAI Responses API](https://learn.microsoft.com/azure/ai-services/openai/how-to/responses) +- [LangChain](https://python.langchain.com/) and [`langchain-mcp-adapters`](https://github.com/langchain-ai/langchain-mcp-adapters) +- [Model Context Protocol](https://modelcontextprotocol.io/) and [FastMCP](https://github.com/jlowin/fastmcp) +- [Azure Developer CLI](https://learn.microsoft.com/azure/developer/azure-developer-cli/) +- [pgvector](https://github.com/pgvector/pgvector) +- This sample is inspired by the [Microsoft AI Tour WRK540 workshop](https://github.com/microsoft/aitour26-WRK540-unlock-your-agents-potential-with-model-context-protocol) and reuses its product catalogue. ## Contributing -This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA). For details, visit https://cla.opensource.microsoft.com. +This project welcomes contributions. Most contributions require you to agree to a Contributor License Agreement; see [https://cla.opensource.microsoft.com](https://cla.opensource.microsoft.com). ## License -This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. +MIT — see [LICENSE](LICENSE). --- -**Questions or feedback?** Open an issue on [GitHub](https://github.com/Azure-Samples/langchain-agent-python/issues) or see [SUPPORT.md](SUPPORT.md). +Questions? Open an issue on [GitHub](https://github.com/Azure-Samples/langchain-agent-python/issues) or read [SUPPORT.md](SUPPORT.md). diff --git a/agent/Dockerfile b/agent/Dockerfile index 9bcc55c..845a24d 100644 --- a/agent/Dockerfile +++ b/agent/Dockerfile @@ -1,19 +1,24 @@ +# syntax=docker/dockerfile:1.7 FROM python:3.11-slim +# Install uv for fast, reproducible Python installs. 
+COPY --from=ghcr.io/astral-sh/uv:0.5.11 /uv /uvx /bin/ + WORKDIR /app +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + UV_SYSTEM_PYTHON=1 \ + UV_LINK_MODE=copy -# Copy requirements COPY requirements.txt . +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system -r requirements.txt -# Install dependencies -RUN pip install --no-cache-dir -r requirements.txt - -# Copy application code COPY *.py . COPY *.txt . COPY static/ ./static/ EXPOSE 8000 -# Run with uvicorn CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"] + diff --git a/agent/app.py b/agent/app.py index 95487e2..30a1475 100644 --- a/agent/app.py +++ b/agent/app.py @@ -1,609 +1,240 @@ +"""Agent API ASGI application. + +A LangChain v1 agent backed by Azure OpenAI (Responses API) and an MCP +server reached over `langchain-mcp-adapters`. One unified code path serves +both local development and Azure Container Apps deployments. + +Heavy initialisation (Azure credential, model, MCP client + tools, agent) +happens once in the Starlette `lifespan` and is reused across requests. """ -Agent API ASGI Application -LangChain agent with MCP tool support, using Azure OpenAI -Runs with uvicorn on Azure Container Apps -""" +from __future__ import annotations + +import asyncio import json import logging import os +from contextlib import asynccontextmanager from pathlib import Path +from typing import Any from dotenv import load_dotenv -from langchain_core.tools import tool -# Load environment variables from .env.local (for local development) -env_path = Path(__file__).parent.parent / ".env.local" -load_dotenv(env_path) +# Load .env.local before importing anything that reads env at import time. +load_dotenv(Path(__file__).parent.parent / ".env.local") from azure.identity import DefaultAzureCredential, get_bearer_token_provider from langchain.agents import create_agent +from langchain_mcp_adapters.client import MultiServerMCPClient from langchain_openai import ChatOpenAI from starlette.applications import Starlette from starlette.responses import FileResponse, JSONResponse, StreamingResponse from starlette.routing import Route -# Configure logging +from streaming import iter_message_events + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Determine environment -environment = os.getenv("ENVIRONMENT", "production") -is_local = environment == "local" +# ---- Configuration --------------------------------------------------------- +ENVIRONMENT = os.getenv("ENVIRONMENT", "production") + +OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "").rstrip("/") +if OPENAI_ENDPOINT and not OPENAI_ENDPOINT.endswith("/openai/v1"): + OPENAI_ENDPOINT = f"{OPENAI_ENDPOINT}/openai/v1" -# Load system instructions once at module level -system_instructions_path = Path(__file__).parent / "instructions.txt" -with open(system_instructions_path, "r") as f: - system_prompt = f.read().strip() +OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT", "gpt-5-mini") -# Configuration from environment -openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "").rstrip("/") -if openai_endpoint and not openai_endpoint.endswith("/openai/v1"): - openai_endpoint = f"{openai_endpoint}/openai/v1" +MCP_SERVER_URL = os.getenv("MCP_SERVER_URL", "http://localhost:8000").rstrip("/") +if not MCP_SERVER_URL.endswith("/mcp"): + MCP_SERVER_URL = f"{MCP_SERVER_URL}/mcp" -openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT", "gpt-5-mini") +# Image generation pairs `image_generation` (server-side) with custom MCP +# tools, which 
historically tripped a partial_images mutation bug +# (langchain-ai/langchain#34136). We keep it behind an env flag so the +# default deploy is conservative and operators can opt in once they've +# verified their langchain stack version. +ENABLE_IMAGE_GENERATION = os.getenv("ENABLE_IMAGE_GENERATION", "false").lower() in ("1", "true", "yes") -mcp_server_url = os.getenv("MCP_SERVER_URL", "http://localhost:8000").rstrip("/") -if not mcp_server_url.endswith("/mcp"): - mcp_server_url = f"{mcp_server_url}/mcp" +with open(Path(__file__).parent / "instructions.txt", "r") as fh: + SYSTEM_PROMPT = fh.read().strip() -# For local mode, we'll load MCP tools via langchain-mcp-adapters -mcp_tools = [] -mcp_client = None -if is_local: - logger.info("🔧 Running in LOCAL mode - using langchain-mcp-adapters for MCP tools") - from langchain_mcp_adapters.client import MultiServerMCPClient +# ---- Lifespan: build the agent once at startup ----------------------------- +async def _connect_mcp_with_retry(client: MultiServerMCPClient, attempts: int = 5) -> list: + """Fetch MCP tools, retrying with exponential backoff on transient errors. + + Container Apps may start the agent before the MCP service is reachable; + we want a few retries before crash-looping the container. + """ + delay = 1.0 + last_exc: Exception | None = None + for i in range(1, attempts + 1): + try: + tools = await client.get_tools() + logger.info("📦 Loaded %d MCP tool(s) from %s", len(tools), MCP_SERVER_URL) + return tools + except Exception as exc: + last_exc = exc + logger.warning("MCP get_tools attempt %d/%d failed: %s", i, attempts, exc) + if i < attempts: + await asyncio.sleep(delay) + delay *= 2 + raise RuntimeError(f"Could not reach MCP server at {MCP_SERVER_URL}") from last_exc + + +@asynccontextmanager +async def lifespan(app: Starlette): + logger.info("Initialising agent (env=%s, mcp=%s)…", ENVIRONMENT, MCP_SERVER_URL) + + credential = DefaultAzureCredential() + token_provider = get_bearer_token_provider(credential, "https://cognitiveservices.azure.com/.default") - # Create MCP client for local server mcp_client = MultiServerMCPClient( - { - "zava-sales": { - "url": mcp_server_url, - "transport": "streamable_http", - } - } + {"zava-sales": {"url": MCP_SERVER_URL, "transport": "streamable_http"}} ) -else: - logger.info( - "🚀 Running in PRODUCTION mode - using Azure OpenAI Responses API for MCP" + mcp_tools = await _connect_mcp_with_retry(mcp_client) + + server_tools: list[dict[str, Any]] = [ + {"type": "web_search_preview"}, + {"type": "code_interpreter", "container": {"type": "auto"}}, + ] + if ENABLE_IMAGE_GENERATION: + server_tools.append({"type": "image_generation", "quality": "low"}) + logger.info("🎨 image_generation enabled") + + model = ChatOpenAI( + model=OPENAI_DEPLOYMENT, + base_url=OPENAI_ENDPOINT, + api_key=token_provider, + streaming=True, + use_responses_api=True, + include=["code_interpreter_call.outputs"], ) + agent = create_agent( + model=model, + tools=server_tools + mcp_tools, + system_prompt=SYSTEM_PROMPT, + ) + + app.state.agent = agent + app.state.mcp_tool_count = len(mcp_tools) + app.state.image_generation_enabled = ENABLE_IMAGE_GENERATION + logger.info("✅ Agent ready") -async def chat_ui_endpoint(request): - """Serve the chat UI.""" try: - html_path = Path(__file__).parent / "static" / "index.html" - return FileResponse(html_path, media_type="text/html") - except Exception as e: - logger.error(f"Error loading chat UI: {e}", exc_info=True) - return JSONResponse( - {"error": f"Error loading chat UI: {str(e)}"}, 
status_code=500
-        )
+        yield
+    finally:
+        try:
+            credential.close()
+        except Exception:
+            pass
+
+
+# ---- Endpoints -------------------------------------------------------------
+async def chat_ui_endpoint(request):
+    return FileResponse(Path(__file__).parent / "static" / "index.html", media_type="text/html")


async def chat_endpoint(request):
-    """
-    Chat endpoint for the agent with streaming support.
+    agent = getattr(request.app.state, "agent", None)
+    if agent is None:
+        return JSONResponse({"error": "Agent is not ready yet. Try again in a few seconds."}, status_code=503)

-    Request body:
-    {
-        "message": "user message",
-        "history": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
-    }
-    """
    try:
-        # Parse request body
-        req_body = await request.json()
-        message = req_body.get("message")
-        history = req_body.get("history", [])
-
-        if not message:
-            return JSONResponse({"error": "message is required"}, status_code=400)
-
-        # Initialize Azure credential and token provider
-        credential = DefaultAzureCredential()
-        token_provider = get_bearer_token_provider(
-            credential, "https://cognitiveservices.azure.com/.default"
-        )
+        body = await request.json()
+    except json.JSONDecodeError:
+        return JSONResponse({"error": "Invalid JSON body"}, status_code=400)
+
+    message = body.get("message")
+    history = body.get("history", []) or []
+    if not message:
+        return JSONResponse({"error": "message is required"}, status_code=400)
+
+    messages = [{"role": m["role"], "content": m.get("content", "")} for m in history if m.get("role")]
+    messages.append({"role": "user", "content": message})
+
+    async def generate_stream():
+        full_text: list[str] = []
+        images: list[dict] = []
+        status_active = False
+
+        def encode(obj: dict) -> str:
+            # Plain helper; nothing asynchronous about serialising a dict.
+            return json.dumps(obj) + "\n"
+
+        try:
+            async for chunk in agent.astream({"messages": messages}, stream_mode="messages"):
+                msg = chunk[0] if isinstance(chunk, tuple) else chunk
+
+                for ev in iter_message_events(msg):
+                    kind = ev["kind"]
+                    if kind == "text":
+                        if status_active:
+                            yield encode({"status": ""})
+                            status_active = False
+                        full_text.append(ev["text"])
+                        yield encode({"chunk": ev["text"]})
+                    elif kind == "image":
+                        if status_active:
+                            yield encode({"status": ""})
+                            status_active = False
+                        images.append(ev["image"])
+                        yield encode({"image": ev["image"]})
+                    elif kind == "status_start":
+                        if not status_active:
+                            yield encode({"status": ev["status"]})
+                            status_active = True
+                    elif kind == "status_end":
+                        if status_active:
+                            yield encode({"status": ""})
+                            status_active = False
+        except Exception as exc:
+            logger.exception("Error during agent stream")
+            yield encode({"error": f"agent stream failed: {exc}"})
+
+        if status_active:
+            yield encode({"status": ""})
+        yield encode(
+            {
+                "message": "".join(full_text),
+                "role": "assistant",
+                "images": images,
+                "done": True,
+            }
+        )

-        # Build tools list - differs between local and production
-        if is_local:
-            # LOCAL MODE: Use a separate image agent to avoid the partial_images
-            # mutation bug when combining image_generation with custom tools.
- # See: https://github.com/langchain-ai/langchain/pull/34136 - - # Create a dedicated image generation agent (no custom tools = no bug) - image_model = ChatOpenAI( - model=openai_deployment, - base_url=openai_endpoint, - api_key=token_provider, - temperature=0.7, - streaming=True, - use_responses_api=True, - ) - image_agent = create_agent( - model=image_model, - tools=[{"type": "image_generation", "quality": "low"}], - system_prompt="You are an image generation assistant. Generate images based on the user's description. Be creative and descriptive.", - ) - - # Wrap the image agent as a tool for the main agent - @tool - async def generate_image(description: str) -> dict: - """Generate an image based on a text description. Use this when the user asks you to create, draw, or generate an image. Returns a dictionary with image data.""" - result = await image_agent.ainvoke( - {"messages": [{"role": "user", "content": description}]} - ) - # Extract image data from the response - last_message = result["messages"][-1] - content = last_message.content - - # Parse the content to find image blocks - if isinstance(content, list): - for block in content: - if isinstance(block, dict) and block.get("type") == "image": - return { - "type": "image", - "base64": block.get("base64", ""), - "format": block.get("format", "png"), - } - elif hasattr(block, "type") and block.type == "image": - return { - "type": "image", - "base64": getattr(block, "base64", ""), - "format": getattr(block, "format", "png"), - } - - # If no image block found, return the text content - return {"type": "text", "content": str(content)} - - # Responses API tools (without image_generation - it's now a subagent) - responses_api_tools = [ - {"type": "web_search_preview"}, - {"type": "code_interpreter", "container": {"type": "auto"}}, - ] - # Get MCP tools via langchain-mcp-adapters - mcp_tools = await mcp_client.get_tools() - logger.info(f"📦 Loaded {len(mcp_tools)} MCP tools for local mode") - # Combine: Responses API tools + image tool (subagent) + MCP tools - all_tools = responses_api_tools + [generate_image] + mcp_tools - else: - # PRODUCTION MODE: All tools work via Azure OpenAI remote handling - all_tools = [ - { - "type": "mcp", - "server_label": "zava-sales", - "server_url": mcp_server_url, - "require_approval": "never", - }, - {"type": "web_search_preview"}, - {"type": "image_generation", "quality": "low"}, - {"type": "code_interpreter", "container": {"type": "auto"}}, - ] - - # Create model with all tools bound - model = ChatOpenAI( - model=openai_deployment, - base_url=openai_endpoint, - api_key=token_provider, - temperature=0.7, - streaming=True, - use_responses_api=True, - include=["code_interpreter_call.outputs"], - ) - - # Create agent with the same tools - agent = create_agent(model=model, tools=all_tools, system_prompt=system_prompt) - - # Build messages for agent - messages = [] - - # Add history - if history: - for msg in history: - messages.append({"role": msg["role"], "content": msg["content"]}) - - # Add current message - messages.append({"role": "user", "content": message}) - - # Helper function to get appropriate status message based on tool name - def get_tool_status(tool_names: list) -> str: - """Return appropriate status message based on the tool being called.""" - for name in tool_names: - name_lower = name.lower() if name else "" - # Check for our MCP tools first (more specific matches) - if "semantic_search" in name_lower: - return "🔍 Searching products..." 
- elif ( - "execute_sales_query" in name_lower - or "get_table_schemas" in name_lower - ): - return "🔍 Querying database..." - elif "get_current_utc_date" in name_lower: - return "⏰ Getting current time..." - # Then check for built-in tools - elif "image" in name_lower or "generate_image" in name_lower: - return "🎨 Generating image..." - elif "web_search" in name_lower: - return "🔎 Searching the web..." - elif "code_interpreter" in name_lower or "code" in name_lower: - return "💻 Running code..." - elif any( - db_term in name_lower - for db_term in [ - "query", - "sql", - "database", - "db", - "sales", - "customer", - "order", - "product", - ] - ): - return "🔍 Querying database..." - # Default status for unknown tools - return f"⚙️ Using {tool_names[0] if tool_names else 'tool'}..." - - # Async generator for true streaming - async def generate_stream(): - """Stream chunks as they arrive from the agent.""" - full_response = "" - images = [] - tool_in_progress = False - - # Stream with stream_mode="messages" to get token-by-token output - async for chunk in agent.astream( - {"messages": messages}, stream_mode="messages" - ): - # Handle different chunk formats - if isinstance(chunk, tuple): - token, metadata = chunk - else: - token = chunk - - # Skip tool calls and tool results - only show AI responses - # Check message type - msg_type = getattr(token, "type", None) - if msg_type in ("tool", "function"): - # This is a tool result - check for images from code_interpreter - tool_content = getattr(token, "content", "") - if tool_content and isinstance(tool_content, str): - # Check for base64 image data in tool output - if "base64" in tool_content and ( - "image" in tool_content or "png" in tool_content - ): - try: - tool_data = json.loads(tool_content) - if ( - isinstance(tool_data, dict) - and tool_data.get("type") == "image" - ): - image_data = { - "base64": tool_data.get("base64", ""), - "format": tool_data.get("format", "png"), - } - images.append(image_data) - yield json.dumps({"image": image_data}) + "\n" - except json.JSONDecodeError: - pass - continue - - # Check if this is a tool call message - if hasattr(token, "tool_calls") and token.tool_calls: - # AI is calling a tool - send appropriate status update - if not tool_in_progress: - tool_in_progress = True - tool_names = [ - tc.get("name", "tool") - if isinstance(tc, dict) - else getattr(tc, "name", "tool") - for tc in token.tool_calls - ] - status_msg = get_tool_status(tool_names) - yield json.dumps({"status": status_msg}) + "\n" - continue - - # Check for additional_kwargs with tool_calls - if hasattr(token, "additional_kwargs"): - if token.additional_kwargs.get("tool_calls"): - if not tool_in_progress: - tool_in_progress = True - # Extract tool names from additional_kwargs - tool_calls = token.additional_kwargs.get("tool_calls", []) - tool_names = [ - tc.get("function", {}).get("name", "tool") - if isinstance(tc, dict) - else getattr(tc, "name", "tool") - for tc in tool_calls - ] - status_msg = get_tool_status(tool_names) - yield json.dumps({"status": status_msg}) + "\n" - continue - - # Reset tool status when we get actual content - if tool_in_progress: - tool_in_progress = False - yield json.dumps({"status": ""}) + "\n" # Clear status - - # Only process AI messages with actual content - if hasattr(token, "content"): - content = token.content - - # Skip empty content - if not content: - continue - - # Check response_metadata for code_interpreter outputs (contains images) - if hasattr(token, "response_metadata"): - resp_meta = 
token.response_metadata - if isinstance(resp_meta, dict): - # Check for code_interpreter output in response - outputs = resp_meta.get("code_interpreter_call", {}).get( - "outputs", [] - ) - if not outputs: - outputs = resp_meta.get("outputs", []) - for output in outputs: - if isinstance(output, dict): - # Handle file outputs (like PNG images) - if output.get("type") == "files": - for file_info in output.get("files", []): - if "image" in file_info.get( - "mime_type", "" - ): - # The file content is base64 encoded - image_data = { - "base64": file_info.get( - "file_data", "" - ), - "format": file_info.get( - "mime_type", "image/png" - ).split("/")[-1], - } - images.append(image_data) - yield ( - json.dumps({"image": image_data}) - + "\n" - ) - # Handle image type outputs directly - elif output.get("type") == "image": - image_data = { - "base64": output.get( - "base64", output.get("data", "") - ), - "format": output.get("format", "png"), - } - images.append(image_data) - yield json.dumps({"image": image_data}) + "\n" - - # Handle content_blocks (LangChain Responses API format) - # See: https://docs.langchain.com/oss/python/langchain/messages#message-content - if isinstance(content, list): - for block in content: - if isinstance(block, dict): - block_type = block.get("type") - - # Text content block - if block_type == "text": - text = block.get("text", "") - if text: - full_response += text - yield json.dumps({"chunk": text}) + "\n" - - # Reasoning block (from reasoning models) - elif block_type == "reasoning": - # Skip reasoning blocks - don't show to user - pass - - # Server tool call (web_search, file_search, etc.) - elif block_type == "server_tool_call": - tool_name = block.get("name", "tool") - if not tool_in_progress: - tool_in_progress = True - status_msg = get_tool_status([tool_name]) - yield json.dumps({"status": status_msg}) + "\n" - - # Server tool result - elif block_type == "server_tool_result": - # Tool completed, reset status - if tool_in_progress: - tool_in_progress = False - yield json.dumps({"status": ""}) + "\n" - - # Code interpreter call with outputs - elif block_type == "code_interpreter_call": - outputs = block.get("outputs", []) - - # Show status when code interpreter is seen (before outputs) - # Only show if we haven't shown status yet AND no outputs yet - if not tool_in_progress and not outputs: - tool_in_progress = True - yield ( - json.dumps({"status": "💻 Running code..."}) - + "\n" - ) - - # If we have outputs, clear status first then process - if outputs: - if tool_in_progress: - tool_in_progress = False - yield json.dumps({"status": ""}) + "\n" - for output in outputs: - if isinstance(output, dict): - if output.get("type") == "image": - # Image can be in 'url' as data URI or direct 'base64' - url = output.get("url", "") - if url.startswith("data:image/"): - # Parse data URI: data:image/png;base64,XXXX - parts = url.split(",", 1) - if len(parts) == 2: - # Extract format from mime type - mime_part = parts[0] - b64_data = parts[1] - img_format = "png" - if "image/" in mime_part: - img_format = ( - mime_part.split( - "image/" - )[1].split(";")[0] - ) - image_data = { - "base64": b64_data, - "format": img_format, - } - images.append(image_data) - yield ( - json.dumps( - {"image": image_data} - ) - + "\n" - ) - else: - # Direct base64 - b64 = output.get( - "base64", output.get("data", "") - ) - if b64: - image_data = { - "base64": b64, - "format": output.get( - "format", "png" - ), - } - images.append(image_data) - yield ( - json.dumps( - {"image": image_data} - ) - + 
"\n" - ) - - # Direct image block (from image_generation) - elif block_type == "image": - url = block.get("url", "") - if url.startswith("data:image/"): - parts = url.split(",", 1) - if len(parts) == 2: - mime_part = parts[0] - b64_data = parts[1] - img_format = "png" - if "image/" in mime_part: - img_format = mime_part.split("image/")[ - 1 - ].split(";")[0] - image_data = { - "base64": b64_data, - "format": img_format, - } - images.append(image_data) - yield ( - json.dumps({"image": image_data}) + "\n" - ) - else: - b64 = block.get("base64", "") - if b64: - image_data = { - "base64": b64, - "format": block.get("format", "png"), - } - images.append(image_data) - yield ( - json.dumps({"image": image_data}) + "\n" - ) - - # Handle object-style blocks (older format) - elif hasattr(block, "text"): - text = block.text - if text: - full_response += text - yield json.dumps({"chunk": text}) + "\n" - elif ( - hasattr(block, "type") - and getattr(block, "type", None) == "image" - ): - image_data = { - "base64": getattr( - block, "base64", getattr(block, "data", "") - ), - "format": getattr(block, "format", "png"), - } - if image_data["base64"]: - images.append(image_data) - yield json.dumps({"image": image_data}) + "\n" - elif isinstance(content, str) and content: - # Check if content contains image data from tool result - if content.startswith("{") and '"type": "image"' in content: - try: - img_data = json.loads(content) - if img_data.get("type") == "image": - image_data = { - "base64": img_data.get("base64", ""), - "format": img_data.get("format", "png"), - } - images.append(image_data) - yield json.dumps({"image": image_data}) + "\n" - continue - except json.JSONDecodeError: - pass - full_response += content - yield json.dumps({"chunk": content}) + "\n" - - # Send final complete message - yield ( - json.dumps( - { - "message": full_response, - "role": "assistant", - "images": images, - "done": True, - } - ) - + "\n" - ) - - return StreamingResponse(generate_stream(), media_type="application/json") - - except ValueError as e: - logger.error(f"ValueError in chat endpoint: {e}", exc_info=True) - return JSONResponse({"error": str(e)}, status_code=400) - except Exception as e: - logger.error(f"Error in chat endpoint: {e}", exc_info=True) - return JSONResponse( - {"error": f"Internal server error: {str(e)}"}, status_code=500 - ) + return StreamingResponse(generate_stream(), media_type="application/json") async def health_endpoint(request): - """Health check endpoint.""" - try: - return JSONResponse( - { - "status": "healthy", - "environment": environment, - "openai_endpoint": openai_endpoint, - "mcp_server": mcp_server_url, - } - ) - except Exception as e: - logger.error(f"Health check error: {e}") - return JSONResponse({"status": "unhealthy", "error": str(e)}, status_code=503) + state = request.app.state + ready = hasattr(state, "agent") and state.agent is not None + return JSONResponse( + { + "status": "healthy" if ready else "starting", + "ready": ready, + "environment": ENVIRONMENT, + "openai_endpoint": OPENAI_ENDPOINT, + "mcp_server": MCP_SERVER_URL, + "mcp_tool_count": getattr(state, "mcp_tool_count", 0), + "image_generation_enabled": getattr(state, "image_generation_enabled", False), + }, + status_code=200 if ready else 503, + ) -# Define routes +# ---- App ------------------------------------------------------------------- routes = [ Route("/", chat_ui_endpoint, methods=["GET"]), Route("/api/chat", chat_endpoint, methods=["POST"]), Route("/api/health", health_endpoint, methods=["GET"]), ] -# 
Create Starlette app -app = Starlette(debug=False, routes=routes) +app = Starlette(debug=False, routes=routes, lifespan=lifespan) if __name__ == "__main__": import uvicorn - port = int(os.getenv("PORT", "8000")) - uvicorn.run(app, host="0.0.0.0", port=port) + uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000"))) diff --git a/agent/requirements.txt b/agent/requirements.txt index 8a23c84..388f1b3 100644 --- a/agent/requirements.txt +++ b/agent/requirements.txt @@ -1,9 +1,9 @@ -starlette>=0.27.0 -uvicorn>=0.23.0 -langchain>=0.3.0 -langchain-openai>=0.3.29 -langchain-core>=0.3.0 -langchain-mcp-adapters>=0.1.0 +starlette>=0.40.0 +uvicorn>=0.30.0 +langchain>=1.0 +langchain-openai>=1.0 +langchain-core>=1.0 +langchain-mcp-adapters>=0.2.2 azure-identity>=1.25.1 python-dotenv>=1.0.0 pydantic>=2.0.0 diff --git a/agent/streaming.py b/agent/streaming.py new file mode 100644 index 0000000..8ea18e7 --- /dev/null +++ b/agent/streaming.py @@ -0,0 +1,214 @@ +""" +Helpers for converting a streamed LangChain agent message into the NDJSON +event stream used by the chat UI. + +The chat API emits one JSON object per line. Event kinds: + + {"chunk": "text"} # each token of assistant text + {"status": "..."} # tool announcement (empty = clear) + {"image": {"base64": "...", "format": "png"}} + {"message": "...", "role": "assistant", "images": [...], "done": true} + +`iter_message_events` is a pure function over a single AIMessageChunk-like +object. It yields zero or more event dicts. The caller is responsible for +serialising them and managing the running `full_text` and `images` lists. + +Keeping this module free of network I/O makes it cheap to unit-test against +hand-built chunks that mimic real LangChain v1 streaming shapes. +""" + +from __future__ import annotations + +import json +from typing import Any, Iterable + + +def tool_status_for(tool_names: list[str]) -> str: + """Map a list of tool names to a friendly status string.""" + for name in tool_names: + n = (name or "").lower() + if "semantic_search" in n: + return "🔍 Searching products..." + if "execute_sales_query" in n or "get_table_schemas" in n: + return "🔍 Querying database..." + if "get_current_utc_date" in n: + return "⏰ Getting current time..." + if "image" in n or "generate_image" in n: + return "🎨 Generating image..." + if "web_search" in n: + return "🔎 Searching the web..." + if "code_interpreter" in n or "code" in n: + return "💻 Running code..." + if any(t in n for t in ("query", "sql", "database", "db", "sales", "customer", "order", "product")): + return "🔍 Querying database..." + return f"⚙️ Using {tool_names[0] if tool_names else 'tool'}..." 
+ + +def _extract_image(block: dict) -> dict | None: + """Pull a {base64, format} image dict out of a content block in any of the + shapes produced by LangChain / OpenAI Responses API.""" + url = block.get("url", "") or "" + if isinstance(url, str) and url.startswith("data:image/"): + try: + mime, data = url.split(",", 1) + except ValueError: + return None + fmt = "png" + if "image/" in mime: + fmt = mime.split("image/", 1)[1].split(";", 1)[0] or "png" + return {"base64": data, "format": fmt} + + b64 = block.get("base64") or block.get("data") or block.get("file_data") + if b64: + fmt = block.get("format") or block.get("mime_type", "image/png").split("/")[-1] + return {"base64": b64, "format": fmt or "png"} + return None + + +def _iter_blocks(blocks: Any) -> Iterable[dict]: + """Yield events for an iterable of content blocks (dict or object).""" + if not isinstance(blocks, list): + return + for b in blocks: + if isinstance(b, dict): + t = b.get("type") + if t == "text": + text = b.get("text") or "" + if text: + yield {"kind": "text", "text": text} + elif t == "image": + img = _extract_image(b) + if img: + yield {"kind": "image", "image": img} + elif t == "server_tool_call": + yield {"kind": "status_start", "status": tool_status_for([b.get("name", "")])} + elif t == "server_tool_result": + yield {"kind": "status_end"} + elif t == "code_interpreter_call": + outputs = b.get("outputs") or [] + if not outputs: + yield {"kind": "status_start", "status": "💻 Running code..."} + continue + yield {"kind": "status_end"} + for o in outputs: + if isinstance(o, dict) and o.get("type") == "image": + img = _extract_image(o) + if img: + yield {"kind": "image", "image": img} + elif isinstance(o, dict) and o.get("type") == "files": + for f in o.get("files", []): + if isinstance(f, dict) and "image" in (f.get("mime_type") or ""): + img = _extract_image(f) + if img: + yield {"kind": "image", "image": img} + # reasoning/tool_use/etc → ignored on purpose + else: + text = getattr(b, "text", None) + if text: + yield {"kind": "text", "text": text} + elif getattr(b, "type", None) == "image": + img = { + "base64": getattr(b, "base64", "") or getattr(b, "data", ""), + "format": getattr(b, "format", "png"), + } + if img["base64"]: + yield {"kind": "image", "image": img} + + +def _tool_names_from_chunk(msg: Any) -> list[str]: + """Extract tool names from `tool_calls` (langchain) or `additional_kwargs` (raw OpenAI).""" + names: list[str] = [] + tool_calls = getattr(msg, "tool_calls", None) or [] + for tc in tool_calls: + if isinstance(tc, dict): + n = tc.get("name") or tc.get("function", {}).get("name") + else: + n = getattr(tc, "name", None) + if n: + names.append(n) + if names: + return names + extras = getattr(msg, "additional_kwargs", None) or {} + for tc in extras.get("tool_calls", []) or []: + if isinstance(tc, dict): + n = tc.get("name") or tc.get("function", {}).get("name") + if n: + names.append(n) + return names + + +def _maybe_image_from_tool_string(content: str) -> dict | None: + """Some custom MCP tools return a JSON-stringified image dict.""" + if not isinstance(content, str) or '"type"' not in content[:100] or "image" not in content[:100]: + return None + try: + data = json.loads(content) + except json.JSONDecodeError: + return None + if isinstance(data, dict) and data.get("type") == "image": + b64 = data.get("base64") or data.get("data") + if b64: + return {"base64": b64, "format": data.get("format", "png")} + return None + + +def iter_message_events(msg: Any) -> Iterable[dict]: + """Yield event dicts for 
one streamed message chunk. + + Event kinds: + {"kind": "text", "text": str} + {"kind": "image", "image": {...}} + {"kind": "status_start", "status": str} + {"kind": "status_end"} # caller emits {"status": ""} + """ + msg_type = getattr(msg, "type", None) + + # Tool/function results: only interesting when a custom MCP tool returns + # a JSON-encoded image dict. Otherwise we ignore them - the AI message's + # content_blocks already reflect what the user should see. + if msg_type in ("tool", "function"): + content = getattr(msg, "content", "") or "" + img = _maybe_image_from_tool_string(content) if isinstance(content, str) else None + if img: + yield {"kind": "image", "image": img} + return + + # AI requesting tools - announce status (layered: tool_calls then additional_kwargs). + tool_names = _tool_names_from_chunk(msg) + if tool_names: + yield {"kind": "status_start", "status": tool_status_for(tool_names)} + return + + # response_metadata may carry code_interpreter outputs out-of-band. + rmeta = getattr(msg, "response_metadata", None) + if isinstance(rmeta, dict): + outputs = rmeta.get("code_interpreter_call", {}).get("outputs") or rmeta.get("outputs") or [] + for o in outputs: + if isinstance(o, dict): + if o.get("type") == "files": + for f in o.get("files", []): + if isinstance(f, dict) and "image" in (f.get("mime_type") or ""): + img = _extract_image(f) + if img: + yield {"kind": "image", "image": img} + elif o.get("type") == "image": + img = _extract_image(o) + if img: + yield {"kind": "image", "image": img} + + # Standard LC v1 normalized blocks (preferred path). + blocks = getattr(msg, "content_blocks", None) + if blocks: + yield from _iter_blocks(blocks) + return + + # Fallback: raw content (may be a list of provider blocks or a plain string). + content = getattr(msg, "content", None) + if isinstance(content, list): + yield from _iter_blocks(content) + elif isinstance(content, str) and content: + img = _maybe_image_from_tool_string(content) + if img: + yield {"kind": "image", "image": img} + else: + yield {"kind": "text", "text": content} diff --git a/agent/test_streaming.py b/agent/test_streaming.py new file mode 100644 index 0000000..d4e5052 --- /dev/null +++ b/agent/test_streaming.py @@ -0,0 +1,154 @@ +"""Contract tests for `streaming.iter_message_events`. + +These tests build hand-crafted message chunks that mimic the shapes a +LangChain v1 / Azure OpenAI Responses API agent produces during streaming, +then assert the public NDJSON event contract is preserved. 
+""" + +from __future__ import annotations + +import sys +from pathlib import Path +from types import SimpleNamespace + +sys.path.insert(0, str(Path(__file__).parent)) + +from streaming import iter_message_events, tool_status_for + + +def chunk(**kwargs): + """Build a SimpleNamespace pretending to be an AIMessageChunk.""" + defaults = {"type": "ai", "content": "", "tool_calls": []} + defaults.update(kwargs) + return SimpleNamespace(**defaults) + + +# ---------- text streaming --------------------------------------------------- +def test_text_via_content_blocks(): + msg = chunk(content_blocks=[{"type": "text", "text": "Hel"}]) + events = list(iter_message_events(msg)) + assert events == [{"kind": "text", "text": "Hel"}] + + +def test_text_via_raw_string_content(): + msg = chunk(content="Hello") + events = list(iter_message_events(msg)) + assert events == [{"kind": "text", "text": "Hello"}] + + +def test_text_via_raw_list_content_when_no_content_blocks(): + msg = chunk(content=[{"type": "text", "text": "world"}]) + events = list(iter_message_events(msg)) + assert events == [{"kind": "text", "text": "world"}] + + +def test_empty_text_block_skipped(): + msg = chunk(content_blocks=[{"type": "text", "text": ""}]) + assert list(iter_message_events(msg)) == [] + + +def test_reasoning_block_is_ignored(): + msg = chunk(content_blocks=[{"type": "reasoning", "text": "thinking..."}]) + assert list(iter_message_events(msg)) == [] + + +# ---------- tool calls ------------------------------------------------------- +def test_tool_call_emits_status_start(): + msg = chunk(tool_calls=[{"name": "semantic_search_products", "args": {}}]) + events = list(iter_message_events(msg)) + assert events == [{"kind": "status_start", "status": "🔍 Searching products..."}] + + +def test_tool_call_via_additional_kwargs(): + msg = chunk(additional_kwargs={"tool_calls": [{"function": {"name": "execute_sales_query"}}]}) + events = list(iter_message_events(msg)) + assert events == [{"kind": "status_start", "status": "🔍 Querying database..."}] + + +def test_server_tool_call_block(): + msg = chunk(content_blocks=[{"type": "server_tool_call", "name": "web_search"}]) + events = list(iter_message_events(msg)) + assert events == [{"kind": "status_start", "status": "🔎 Searching the web..."}] + + +def test_server_tool_result_block(): + msg = chunk(content_blocks=[{"type": "server_tool_result"}]) + events = list(iter_message_events(msg)) + assert events == [{"kind": "status_end"}] + + +# ---------- code interpreter ------------------------------------------------- +def test_code_interpreter_call_no_outputs_announces_status(): + msg = chunk(content_blocks=[{"type": "code_interpreter_call", "outputs": []}]) + assert list(iter_message_events(msg)) == [ + {"kind": "status_start", "status": "💻 Running code..."}, + ] + + +def test_code_interpreter_with_image_data_url(): + msg = chunk( + content_blocks=[ + { + "type": "code_interpreter_call", + "outputs": [{"type": "image", "url": "data:image/png;base64,AAA"}], + } + ] + ) + events = list(iter_message_events(msg)) + assert events == [ + {"kind": "status_end"}, + {"kind": "image", "image": {"base64": "AAA", "format": "png"}}, + ] + + +def test_response_metadata_code_interpreter_files(): + msg = chunk( + response_metadata={ + "code_interpreter_call": { + "outputs": [ + {"type": "files", "files": [{"mime_type": "image/png", "file_data": "ZZZ"}]} + ] + } + } + ) + events = list(iter_message_events(msg)) + assert events == [{"kind": "image", "image": {"base64": "ZZZ", "format": "png"}}] + + +# 
---------- direct image blocks --------------------------------------------- +def test_direct_image_block_data_url(): + msg = chunk(content_blocks=[{"type": "image", "url": "data:image/jpeg;base64,JJJ"}]) + events = list(iter_message_events(msg)) + assert events == [{"kind": "image", "image": {"base64": "JJJ", "format": "jpeg"}}] + + +def test_direct_image_block_raw_base64(): + msg = chunk(content_blocks=[{"type": "image", "base64": "BBB", "format": "png"}]) + events = list(iter_message_events(msg)) + assert events == [{"kind": "image", "image": {"base64": "BBB", "format": "png"}}] + + +# ---------- tool messages with image payload -------------------------------- +def test_tool_message_with_json_image_string(): + msg = chunk( + type="tool", + content='{"type": "image", "base64": "TTT", "format": "png"}', + ) + events = list(iter_message_events(msg)) + assert events == [{"kind": "image", "image": {"base64": "TTT", "format": "png"}}] + + +def test_tool_message_text_is_ignored(): + msg = chunk(type="tool", content="some plain text result") + assert list(iter_message_events(msg)) == [] + + +# ---------- tool_status_for -------------------------------------------------- +def test_tool_status_known_names(): + assert tool_status_for(["semantic_search_products"]).startswith("🔍") + assert tool_status_for(["execute_sales_query"]).startswith("🔍") + assert tool_status_for(["web_search_preview"]).startswith("🔎") + assert tool_status_for(["code_interpreter"]).startswith("💻") + assert tool_status_for(["image_generation"]).startswith("🎨") + assert tool_status_for(["mystery_tool"]).startswith("⚙️") + assert tool_status_for([]).startswith("⚙️") diff --git a/azure.yaml b/azure.yaml index 99ca759..e0d5864 100644 --- a/azure.yaml +++ b/azure.yaml @@ -31,4 +31,7 @@ hooks: shell: sh run: ./infra/hooks/postprovision.sh interactive: true - continueOnError: false + # Database seeding is idempotent and not required for infra to be valid. + # If seeding fails (e.g. local pip/firewall issue), provisioning still succeeds + # and the user can re-run `azd hooks run postprovision` after fixing the issue. + continueOnError: true diff --git a/infra/hooks/postprovision.sh b/infra/hooks/postprovision.sh index 43be578..e552d0d 100755 --- a/infra/hooks/postprovision.sh +++ b/infra/hooks/postprovision.sh @@ -1,82 +1,113 @@ #!/bin/bash -set -e +# Post-provision hook: seed the Postgres database with sample data. +# +# This runs on the developer's machine after `azd up` finishes provisioning. +# It is intentionally tolerant of local environment quirks so that +# infrastructure provisioning is never marked as "failed" because of a +# seeding hiccup. If anything here fails, you can re-run it with: +# +# azd hooks run postprovision +# +set -uo pipefail echo "Running post-provision setup..." -# Get environment values -echo "Retrieving environment values..." 
-POSTGRES_URL=$(azd env get-values --output json | jq -r '.POSTGRES_URL // empty')
+# ---- Resolve env ----------------------------------------------------------
+ENV_VALUES=$(azd env get-values --output json 2>/dev/null || echo '{}')
+POSTGRES_URL=$(printf '%s' "$ENV_VALUES" | jq -r '.POSTGRES_URL // empty')
+POSTGRES_HOST=$(printf '%s' "$ENV_VALUES" | jq -r '.POSTGRES_HOST // empty')
+POSTGRES_RG=$(printf '%s' "$ENV_VALUES" | jq -r '.AZURE_RESOURCE_GROUP // empty')
+AGENT_URL=$(printf '%s' "$ENV_VALUES" | jq -r '.AGENT_URL // empty')
+MCP_SERVER_URL=$(printf '%s' "$ENV_VALUES" | jq -r '.MCP_SERVER_URL // empty')

-# Populate database with sales data
-if [ -n "$POSTGRES_URL" ]; then
-    echo ""
-    echo "📊 Populating database with sales data..."
-
-    # Check if data files exist
-    if [ ! -f "data/product_data.json" ] || [ ! -f "data/reference_data.json" ]; then
-        echo "⚠️  Warning: Data files not found. Downloading..."
-        echo ""
-        echo "Downloading product_data.json (~2-3 GB, this may take a few minutes)..."
-        curl -L --progress-bar \
-            "https://raw.githubusercontent.com/microsoft/aitour26-WRK540-unlock-your-agents-potential-with-model-context-protocol/main/data/database/product_data.json" \
-            -o data/product_data.json
-
-        echo "Downloading reference_data.json..."
-        curl -L --progress-bar \
-            "https://raw.githubusercontent.com/microsoft/aitour26-WRK540-unlock-your-agents-potential-with-model-context-protocol/main/data/database/reference_data.json" \
-            -o data/reference_data.json
+# ---- Seed database --------------------------------------------------------
+if [ -z "$POSTGRES_URL" ]; then
+  echo "⚠️  POSTGRES_URL not found - skipping database population"
+elif [ ! -f "data/product_data.json" ] || [ ! -f "data/reference_data.json" ]; then
+  echo "⚠️  Required data files (data/product_data.json, data/reference_data.json) missing."
+  echo "    These ship with the repo - check that you cloned a complete copy."
+else
+  echo "📊 Populating database with sample data..."
+
+  # Add this machine's public IP to the Postgres firewall so the seed script
+  # can connect. Failures are ignored: the rule may already exist from an
+  # earlier run, and seeding will surface any real connectivity problem.
+  if [ -n "$POSTGRES_HOST" ] && [ -n "$POSTGRES_RG" ]; then
+    PG_SERVER_NAME="${POSTGRES_HOST%%.*}"
+    CLIENT_IP=$(curl -fsS https://api.ipify.org 2>/dev/null || true)
+    if [ -n "$CLIENT_IP" ]; then
+      echo "   Adding firewall rule for $CLIENT_IP on $PG_SERVER_NAME..."
+      az postgres flexible-server firewall-rule create \
+        --resource-group "$POSTGRES_RG" \
+        --name "$PG_SERVER_NAME" \
+        --rule-name "azd-postprovision-$(date +%Y%m%d)" \
+        --start-ip-address "$CLIENT_IP" \
+        --end-ip-address "$CLIENT_IP" \
+        --only-show-errors >/dev/null 2>&1 || echo "   (could not add firewall rule - it may already exist, continuing)"
+    fi
   fi
-
-    # Install dependencies if needed
-    if ! python3 -c "import asyncpg" 2>/dev/null; then
-        echo "Installing required Python packages..."
-        pip install -q asyncpg
+
+  # Pick a working Python interpreter; the seeding dependencies are installed
+  # into a venv to avoid PEP 668 "externally managed environment" failures on
+  # macOS Homebrew / Debian-packaged Python.
+  PY=$(command -v python3 || command -v python || true)
+  if [ -z "$PY" ]; then
+    echo "❌ python3 is not on PATH - install Python 3.11+ then re-run: azd hooks run postprovision"
+    exit 0 # don't block provisioning
   fi
-
-    # Run database generation script
-    echo "Running database generation script..."
-    export POSTGRES_URL="$POSTGRES_URL"
-    python3 data/generate_database.py
-
-    echo "✅ Database populated successfully!"
-else
-    echo "⚠️  POSTGRES_URL not found - skipping database population"
-fi
-echo ""
-echo "✅ Post-provision setup complete!"
-echo ""
+  VENV_DIR=".azd-postprovision-venv"
+  if [ ! -d "$VENV_DIR" ]; then
+    echo "   Creating ephemeral venv for seeding..."
+    "$PY" -m venv "$VENV_DIR" || {
+      echo "❌ Could not create venv. Install python3-venv and re-run: azd hooks run postprovision"
+      exit 0
+    }
+  fi
+  # shellcheck disable=SC1091
+  . "$VENV_DIR/bin/activate"

-# Get service URLs from azd environment
-AGENT_URL=$(azd env get-values --output json | jq -r '.AGENT_URL // empty')
-MCP_SERVER_URL=$(azd env get-values --output json | jq -r '.MCP_SERVER_URL // empty')

+  if ! python -c "import asyncpg, pgvector" 2>/dev/null; then
+    echo "   Installing asyncpg + pgvector..."
+    python -m pip install --quiet --disable-pip-version-check asyncpg pgvector || {
+      echo "❌ pip install failed. Re-run later with: azd hooks run postprovision"
+      exit 0
+    }
+  fi
+
+  echo "   Running data/generate_database.py..."
+  POSTGRES_URL="$POSTGRES_URL" python data/generate_database.py && \
+    echo "✅ Database populated successfully!" || \
+    echo "⚠️  Seeding failed. Fix the issue and re-run: azd hooks run postprovision"
+  deactivate || true
+fi
+
+# ---- Print summary --------------------------------------------------------
+echo ""
 if [ -n "$AGENT_URL" ]; then
-  echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
-  echo "🚀 Your LangChain Agent is Ready!"
-  echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
-  echo ""
-  echo "🌐 WEB CHAT INTERFACE (Open in browser):"
-  echo "   ${AGENT_URL}/"
-  echo ""
-  echo "📊 API ENDPOINTS:"
-  echo "   Chat API:      ${AGENT_URL}/api/chat (POST with JSON)"
-  echo "   Health Check:  ${AGENT_URL}/api/health"
+  cat <<EOF
+
+🚀 Your LangChain Agent is ready!
+
+🌐 Web chat (open in a browser):  ${AGENT_URL}/
+📊 Chat API (POST JSON):          ${AGENT_URL}/api/chat
+❤️  Health check:                 ${AGENT_URL}/api/health
+🔧 MCP server (internal only):    ${MCP_SERVER_URL}
+EOF
 fi
diff --git a/mcp/app.py b/mcp/app.py
--- a/mcp/app.py
+++ b/mcp/app.py
@@ ... @@
+from urllib.parse import quote
@@ ... @@
-def parse_postgres_url(url: str) -> dict:
-    """Parse PostgreSQL URL into connection parameters."""
-    # Parse pattern: postgresql://user:password@host:port/database?params
-    match = re.match(
-        r"postgresql://([^:]+):([^@]+)@([^:]+):(\d+)/([^?]+)(\?(.+))?", url
-    )
-    if match:
-        user, password, host, port, database, _, params = match.groups()
-        result = {
-            "user": user,
-            "password": password,
-            "host": host,
-            "port": int(port),
-            "database": database,
-        }
-        # Parse query parameters
-        if params:
-            for param in params.split("&"):
-                if "=" in param:
-                    key, value = param.split("=", 1)
-                    if key == "sslmode":
-                        result["ssl"] = value
-        return result
-    raise ValueError(f"Invalid PostgreSQL URL format: {url}")
+# Split userinfo on the LAST "@" so passwords that themselves contain "@"
+# are still captured whole.
+_DSN_RE = re.compile(r"^([^:]+://)([^:/@]+):(.+)@([^@]+)$")
+
+
+def _safe_dsn(url: str) -> str:
+    """Re-encode userinfo so asyncpg can parse DSNs with special chars in passwords.
+
+    `azd` provisions Postgres with a generated password that may contain `#`,
+    `@`, `/`, etc. urlparse treats `#` as a fragment delimiter, so we extract
+    user/password with a tolerant regex (splitting on the last `@`), then
+    percent-encode each.
+    """
+    m = _DSN_RE.match(url)
+    if not m:
+        return url
+    scheme, user, password, rest = m.groups()
+    return f"{scheme}{quote(user, safe='')}:{quote(password, safe='')}@{rest}"


 class PostgreSQLProvider:
-    """PostgreSQL database provider with pgvector support."""
+    """PostgreSQL connection pool wrapper with pgvector support."""

-    def __init__(self, connection_url: str):
-        self.connection_params = parse_postgres_url(connection_url)
+    def __init__(self, dsn: str):
+        # asyncpg accepts a DSN string directly; we just URL-encode userinfo
+        # to handle generated passwords with `#`, `@`, etc.
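+        # e.g. "postgresql://u:p#w@host:5432/db" -> "postgresql://u:p%23w@host:5432/db"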
+ self.dsn = _safe_dsn(dsn) self.pool: Optional[asyncpg.Pool] = None async def connect(self): - """Create connection pool.""" try: - self.pool = await asyncpg.create_pool( - **self.connection_params, min_size=1, max_size=10 - ) + self.pool = await asyncpg.create_pool(dsn=self.dsn, min_size=1, max_size=10) logger.info("✅ PostgreSQL connection pool established") except Exception as e: logger.error(f"❌ Failed to connect to PostgreSQL: {e}") raise async def close(self): - """Close connection pool.""" if self.pool: await self.pool.close() logger.info("PostgreSQL connection pool closed") async def execute_query(self, query: str) -> list[dict]: - """Execute a SELECT query and return results.""" if not self.pool: await self.connect() - async with self.pool.acquire() as conn: rows = await conn.fetch(query) return [dict(row) for row in rows] async def get_table_schemas(self) -> str: - """Get schema information for all tables in the retail schema.""" if not self.pool: await self.connect() schema_query = """ - SELECT + SELECT table_name, column_name, data_type, @@ -113,15 +102,9 @@ async def get_table_schemas(self) -> str: async with self.pool.acquire() as conn: rows = await conn.fetch(schema_query) - - # Group by table - tables = {} + tables: dict[str, list] = {} for row in rows: - table_name = row["table_name"] - if table_name not in tables: - tables[table_name] = [] - - tables[table_name].append( + tables.setdefault(row["table_name"], []).append( { "column": row["column_name"], "type": row["data_type"], @@ -129,32 +112,32 @@ async def get_table_schemas(self) -> str: "default": row["column_default"], } ) - return json.dumps(tables, indent=2) class SemanticSearchEmbedding: - """Semantic search using Azure OpenAI embeddings and pgvector.""" + """Semantic search using Azure OpenAI embeddings + pgvector cosine.""" - def __init__(self, openai_endpoint: str, embedding_deployment: str): + def __init__(self, openai_endpoint: str, embedding_deployment: str, api_version: str): self.openai_endpoint = openai_endpoint self.embedding_deployment = embedding_deployment - # Initialize Azure OpenAI async client with Entra ID auth credential = DefaultAzureCredential() token_provider = get_bearer_token_provider( credential, "https://cognitiveservices.azure.com/.default" ) self.client = AsyncAzureOpenAI( - api_version="2024-10-21", + api_version=api_version, azure_endpoint=openai_endpoint, azure_ad_token_provider=token_provider, ) - logger.info(f"✅ Azure OpenAI async client initialized: {openai_endpoint}") + logger.info( + "✅ Azure OpenAI client initialized (endpoint=%s, deployment=%s, api_version=%s)", + openai_endpoint, embedding_deployment, api_version, + ) async def get_embedding(self, text: str) -> list[float]: - """Get embedding vector for text asynchronously.""" response = await self.client.embeddings.create( input=text, model=self.embedding_deployment ) @@ -165,32 +148,25 @@ async def search_products( query: str, max_rows: int = 5, threshold: float = 0.7, - db_pool: asyncpg.Pool = None, - ctx: Context = None, + db_pool: Optional[asyncpg.Pool] = None, + ctx: Optional[Context] = None, ) -> str: - """Search for products using semantic similarity.""" if not db_pool: raise ToolError("Database not connected") - # Report progress: Getting embedding if ctx: await ctx.report_progress(progress=1, total=3) await ctx.info(f"Getting embedding for query: {query[:50]}...") - # Get embedding for query (1536-dim from text-embedding-ada-002) query_embedding = await self.get_embedding(query) - - # Convert embedding list to 
pgvector string format embedding_str = "[" + ",".join(map(str, query_embedding)) + "]" - # Report progress: Searching database if ctx: await ctx.report_progress(progress=2, total=3) await ctx.info("Searching products in database...") - # Search using pgvector cosine similarity on description embeddings search_query = """ - SELECT + SELECT p.product_name, p.product_description, c.category_name, @@ -206,198 +182,119 @@ async def search_products( async with db_pool.acquire() as conn: rows = await conn.fetch(search_query, embedding_str, threshold, max_rows) - - # Report progress: Done if ctx: await ctx.report_progress(progress=3, total=3) if not rows: return f"No products found matching '{query}' with similarity > {threshold}" - results = [] - for row in rows: - results.append( - f"• {row['product_name']} ({row['category_name']}) - " - f"${row['base_price']:.2f} - Similarity: {row['similarity']:.2%}\n" - f" {row['product_description'][:100]}..." - ) - - return "\n\n".join(results) + return "\n\n".join( + f"• {row['product_name']} ({row['category_name']}) - " + f"${row['base_price']:.2f} - Similarity: {row['similarity']:.2%}\n" + f" {row['product_description'][:100]}..." + for row in rows + ) @asynccontextmanager async def lifespan(mcp_server: FastMCP): - """Lifespan context manager for proper initialization and cleanup.""" + """Initialise DB + embedding providers; tear down on shutdown.""" global db_provider, embedding_provider logger.info("🚀 Starting MCP server initialization...") - # Initialize PostgreSQL provider postgres_url = os.getenv("POSTGRES_URL") if postgres_url: db_provider = PostgreSQLProvider(postgres_url) await db_provider.connect() - logger.info("✅ Database provider connected") else: logger.warning("⚠️ POSTGRES_URL not set - database tools will not work") - db_provider = None - # Initialize embedding provider openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") - embedding_deployment = os.getenv( - "AZURE_OPENAI_EMBEDDING_DEPLOYMENT", "text-embedding-ada-002" - ) + embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT", "text-embedding-3-small") + api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-12-01-preview") if openai_endpoint: try: embedding_provider = SemanticSearchEmbedding( - openai_endpoint, embedding_deployment + openai_endpoint, embedding_deployment, api_version ) - logger.info("✅ Embedding provider initialized") except Exception as e: logger.error(f"Failed to initialize embeddings: {e}") - embedding_provider = None else: - logger.warning( - "⚠️ AZURE_OPENAI_ENDPOINT not set - semantic search will not work" - ) - embedding_provider = None + logger.warning("⚠️ AZURE_OPENAI_ENDPOINT not set - semantic search will not work") logger.info("✅ MCP server ready") + try: + yield + finally: + logger.info("🛑 Shutting down MCP server...") + if db_provider: + await db_provider.close() - yield # Server is running - # Cleanup on shutdown - logger.info("🛑 Shutting down MCP server...") - if db_provider: - await db_provider.close() - logger.info("✅ MCP server shutdown complete") +mcp = FastMCP("Zava Sales Analysis Tools", lifespan=lifespan) -# Create MCP server instance with lifespan -mcp = FastMCP("Zava Sales Analysis Tools", lifespan=lifespan) +# Defence-in-depth SQL filter. The proper hardening is a read-only Postgres +# role (no INSERT/UPDATE/DELETE/DDL grants) configured in Bicep. 
This filter
+# remains a belt-and-braces guard: a deny-list will not catch every clever
+# bypass on its own, but combined with a read-only role it adds friction
+# against accidental destructive queries.
+_SQL_FORBIDDEN = (
+    "--", "/*",
+    "DROP ", "DELETE ", "INSERT ", "UPDATE ", "ALTER ", "CREATE ",
+    "TRUNCATE ", "GRANT ", "REVOKE ", "EXEC ", "EXECUTE ", "MERGE ",
+    "CALL ", "COPY ",
+)


 def validate_sql_query(query: str) -> None:
-    """Validate SQL query for safety.
-
-    Raises:
-        ToolError: If the query is not safe to execute
-    """
-    # Strip whitespace and trailing semicolon (trailing semicolon is safe)
+    """Raise ToolError unless `query` is a single SELECT (or WITH ... SELECT) with no banned patterns."""
     normalized = query.strip()
     if normalized.endswith(";"):
         normalized = normalized[:-1].strip()
+    upper = normalized.upper()

-    normalized_upper = normalized.upper()
-
-    # Must start with SELECT
-    if not normalized_upper.startswith("SELECT"):
+    if not (upper.startswith("SELECT") or upper.startswith("WITH")):
         raise ToolError("Only SELECT queries are allowed")
-
-    # Check for multiple statements (semicolon in the middle of query)
     if ";" in normalized:
         raise ToolError("Multiple SQL statements are not allowed")
+    for pat in _SQL_FORBIDDEN:
+        if pat in upper:
+            raise ToolError(f"Query contains forbidden pattern: {pat.strip()}")

-    # Check for dangerous patterns (SQL injection prevention)
-    dangerous_patterns = [
-        "--",  # SQL comments
-        "/*",  # Block comments
-        "DROP ",
-        "DELETE ",
-        "INSERT ",
-        "UPDATE ",
-        "ALTER ",
-        "CREATE ",
-        "TRUNCATE ",
-        "GRANT ",
-        "REVOKE ",
-        "EXEC ",
-        "EXECUTE ",
-    ]
-
-    for pattern in dangerous_patterns:
-        if pattern in normalized_upper:
-            raise ToolError(f"Query contains forbidden pattern: {pattern.strip()}")
-
-
-# MCP Tools
-
-
-@mcp.tool(
-    annotations={
-        "title": "Get Current UTC Date",
-        "readOnlyHint": True,
-        "openWorldHint": False,
-    }
-)
-def get_current_utc_date() -> str:
-    """Get the current UTC date and time.
-
-    Returns:
-        Current UTC timestamp in ISO format
-    """
+# ---- MCP Tools ---------------------------------------------------------------
+@mcp.tool(annotations={"title": "Get Current UTC Date", "readOnlyHint": True, "openWorldHint": False})
+def get_current_utc_date() -> str:
+    """Return the current UTC timestamp in ISO 8601 format."""
     return datetime.now(timezone.utc).isoformat()


-@mcp.tool(
-    annotations={
-        "title": "Get Database Table Schemas",
-        "readOnlyHint": True,
-        "openWorldHint": False,
-    }
-)
+@mcp.tool(annotations={"title": "Get Database Table Schemas", "readOnlyHint": True, "openWorldHint": False})
 async def get_table_schemas(ctx: Context) -> str:
-    """Get the schema information for all database tables.
-
-    Returns:
-        JSON string containing table schemas with columns, types, and constraints
-    """
+    """Return JSON describing the columns of every table in the `retail` schema."""
     if not db_provider:
-        raise ToolError(
-            "Database not configured. Set POSTGRES_URL environment variable."
-        )
-
+        raise ToolError("Database not configured. 
Set POSTGRES_URL environment variable.") try: await ctx.info("Fetching database table schemas...") return await db_provider.get_table_schemas() except Exception as e: await ctx.error(f"Error getting schemas: {e}") - raise ToolError(f"Failed to get table schemas: {str(e)}") + raise ToolError(f"Failed to get table schemas: {e}") -@mcp.tool( - annotations={ - "title": "Execute Sales Query", - "readOnlyHint": True, - "openWorldHint": False, - } -) +@mcp.tool(annotations={"title": "Execute Sales Query", "readOnlyHint": True, "openWorldHint": False}) async def execute_sales_query( - query: Annotated[ - str, - Field( - description="SQL query to execute against the sales database. All tables are in the 'retail' schema." - ), - ], + query: Annotated[str, Field(description="SQL SELECT query against the 'retail' schema.")], ctx: Context, ) -> str: - """Execute a SQL query against the sales database. - - Args: - query: SQL query to execute (SELECT statements only). All tables are in the 'retail' schema. - - Returns: - JSON string containing query results - """ + """Execute a read-only SQL query and return the results as JSON.""" if not db_provider: - raise ToolError( - "Database not configured. Set POSTGRES_URL environment variable." - ) + raise ToolError("Database not configured. Set POSTGRES_URL environment variable.") - # Validate query for security validate_sql_query(query) - try: await ctx.info(f"Executing query: {query[:100]}...") results = await db_provider.execute_query(query) @@ -405,45 +302,21 @@ async def execute_sales_query( return json.dumps(results, indent=2, default=str) except Exception as e: await ctx.error(f"Error executing query: {e}") - raise ToolError(f"Query execution failed: {str(e)}") + raise ToolError(f"Query execution failed: {e}") -@mcp.tool( - annotations={ - "title": "Semantic Product Search", - "readOnlyHint": True, - "openWorldHint": True, # Calls Azure OpenAI API - } -) +@mcp.tool(annotations={"title": "Semantic Product Search", "readOnlyHint": True, "openWorldHint": True}) async def semantic_search_products( query: Annotated[str, Field(description="Search query to find relevant products")], ctx: Context, - max_rows: Annotated[ - int, Field(description="Maximum number of results to return", ge=1, le=20) - ] = 5, - threshold: Annotated[ - float, Field(description="Similarity threshold (0-1)", ge=0, le=1) - ] = 0.7, + max_rows: Annotated[int, Field(description="Maximum results", ge=1, le=20)] = 5, + threshold: Annotated[float, Field(description="Similarity threshold (0-1)", ge=0, le=1)] = 0.7, ) -> str: - """Search for products using semantic similarity with pgvector. - - Args: - query: Natural language search query - max_rows: Maximum number of results (1-20) - threshold: Minimum similarity score (0-1) - - Returns: - Formatted list of matching products with similarity scores - """ + """Find products by semantic similarity using pgvector cosine distance.""" if not embedding_provider: - raise ToolError( - "Semantic search not configured. Set AZURE_OPENAI_ENDPOINT environment variable." - ) - + raise ToolError("Semantic search not configured. Set AZURE_OPENAI_ENDPOINT.") if not db_provider or not db_provider.pool: - raise ToolError( - "Database not connected. Set POSTGRES_URL environment variable." - ) + raise ToolError("Database not connected. 
Set POSTGRES_URL.") try: return await embedding_provider.search_products( @@ -453,18 +326,17 @@ async def semantic_search_products( raise except Exception as e: await ctx.error(f"Error in semantic search: {e}") - raise ToolError(f"Semantic search failed: {str(e)}") + raise ToolError(f"Semantic search failed: {e}") -# Create the Starlette app (using http_app instead of deprecated streamable_http_app) +# Streamable-HTTP ASGI app app = mcp.http_app() def run(): - """Run the server with uvicorn.""" - port = int(os.getenv("PORT", "8000")) - uvicorn.run(app, host="0.0.0.0", port=port) + uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000"))) if __name__ == "__main__": run() + diff --git a/mcp/requirements.txt b/mcp/requirements.txt index 7ad474f..3dc41da 100644 --- a/mcp/requirements.txt +++ b/mcp/requirements.txt @@ -1,4 +1,4 @@ -fastmcp>=0.3.0 +fastmcp>=2.0,<3 pydantic>=2.0.0 python-dotenv>=1.0.0 asyncpg>=0.29.0 diff --git a/mcp/test_app.py b/mcp/test_app.py new file mode 100644 index 0000000..7649629 --- /dev/null +++ b/mcp/test_app.py @@ -0,0 +1,94 @@ +"""Unit tests for mcp.app helpers (no network).""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).parent)) + +from app import _safe_dsn, validate_sql_query # noqa: E402 +from fastmcp.exceptions import ToolError # noqa: E402 + + +# ---- _safe_dsn -------------------------------------------------------------- +def test_safe_dsn_passthrough_when_no_special_chars(): + url = "postgresql://user:pass@host:5432/db?sslmode=require" + assert _safe_dsn(url) == url + + +def test_safe_dsn_encodes_password_with_hash(): + url = "postgresql://user:pa#ss@host:5432/db" + out = _safe_dsn(url) + assert "pa%23ss" in out + assert out.endswith("@host:5432/db") + + +def test_safe_dsn_preserves_query(): + url = "postgresql://u:p%23w@host:5432/db?sslmode=require" + assert _safe_dsn(url).endswith("?sslmode=require") + + +def test_safe_dsn_no_userinfo(): + url = "postgresql://host:5432/db" + assert _safe_dsn(url) == url + + +# ---- accepted --------------------------------------------------------------- +def test_simple_select_passes(): + validate_sql_query("SELECT 1") + + +def test_select_with_trailing_semicolon_passes(): + validate_sql_query("SELECT 1;") + + +def test_with_cte_passes(): + validate_sql_query("WITH x AS (SELECT 1) SELECT * FROM x") + + +# ---- rejected --------------------------------------------------------------- +@pytest.mark.parametrize( + "query", + [ + "DROP TABLE products", + "DELETE FROM products", + "INSERT INTO products VALUES (1)", + "UPDATE products SET price = 0", + "ALTER TABLE products ADD col INT", + "CREATE TABLE foo (id INT)", + "TRUNCATE products", + "GRANT SELECT ON products TO public", + "REVOKE SELECT ON products FROM public", + "EXEC sp_evil", + "EXECUTE sp_evil", + "MERGE INTO products USING staging ON 1=1", + "CALL evil_proc()", + "COPY products TO '/tmp/x'", + ], +) +def test_destructive_statements_rejected(query): + with pytest.raises(ToolError): + validate_sql_query(query) + + +def test_multiple_statements_rejected(): + with pytest.raises(ToolError): + validate_sql_query("SELECT 1; SELECT 2") + + +def test_comment_rejected(): + with pytest.raises(ToolError): + validate_sql_query("SELECT 1 -- evil") + + +def test_block_comment_rejected(): + with pytest.raises(ToolError): + validate_sql_query("SELECT /* evil */ 1") + + +def test_non_select_keyword_rejected(): + with pytest.raises(ToolError): + validate_sql_query("EXPLAIN SELECT 
1")