diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..1cc3221 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,123 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + # ----------------------------------------------------------------------- + # Unit tests — no external services (fakeredis + SQLite) + # ----------------------------------------------------------------------- + + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.12", "3.13"] + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run unit tests + run: | + uv run pytest tests/ \ + --ignore=tests/test_kafka_integration.py \ + -o "addopts=" \ + -v --tb=long + + # ----------------------------------------------------------------------- + # Kafka integration tests — real broker via docker run + # ----------------------------------------------------------------------- + test-kafka: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Start Kafka broker + run: | + docker run -d --name kafka \ + -p 9092:9092 \ + -e KAFKA_NODE_ID=1 \ + -e KAFKA_PROCESS_ROLES=broker,controller \ + -e KAFKA_CONTROLLER_QUORUM_VOTERS=1@localhost:9093 \ + -e KAFKA_CONTROLLER_LISTENER_NAMES=CONTROLLER \ + -e KAFKA_LISTENERS=PLAINTEXT://:9092,CONTROLLER://:9093 \ + -e KAFKA_ADVERTISED_LISTENERS=PLAINTEXT://localhost:9092 \ + -e KAFKA_LISTENER_SECURITY_PROTOCOL_MAP=PLAINTEXT:PLAINTEXT,CONTROLLER:PLAINTEXT \ + -e KAFKA_INTER_BROKER_LISTENER_NAME=PLAINTEXT \ + -e KAFKA_LOG_CLEANER_MIN_COMPACTION_LAG_MS=0 \ + -e KAFKA_LOG_CLEANER_MIN_CLEANABLE_RATIO=0.01 \ + -e KAFKA_LOG_RETENTION_MS=60000 \ + -e KAFKA_NUM_PARTITIONS=1 \ + -e KAFKA_AUTO_CREATE_TOPICS_ENABLE=true \ + -e KAFKA_GROUP_INITIAL_REBALANCE_DELAY_MS=0 \ + -e CLUSTER_ID=ciTestCluster0001 \ + apache/kafka:3.9.0 + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + enable-cache: true + + - name: Set up Python + run: uv python install 3.12 + + - name: Install dependencies + run: uv sync --dev --extra kafka + + - name: Wait for Kafka to be ready + run: | + echo "Waiting for Kafka..." + for i in $(seq 1 30); do + if nc -z localhost 9092 2>/dev/null; then + echo "Kafka port is open" + sleep 5 + echo "Kafka is ready" + exit 0 + fi + echo " attempt $i/30..." + sleep 2 + done + echo "Kafka failed to start" + docker logs kafka + exit 1 + + - name: Run Kafka integration tests + run: | + uv run pytest tests/test_kafka_integration.py \ + -o "addopts=" \ + -v --tb=long 2>&1 | tee /tmp/kafka_test_output.txt + exit ${PIPESTATUS[0]} + env: + AGENTEXEC_STATE_BACKEND: agentexec.state.kafka + KAFKA_BOOTSTRAP_SERVERS: localhost:9092 + AGENTEXEC_KAFKA_DEFAULT_PARTITIONS: "2" + AGENTEXEC_KAFKA_REPLICATION_FACTOR: "1" + + - name: Show Kafka logs on failure + if: failure() + run: docker logs kafka 2>&1 | tail -50 + + - name: Create failure check annotation with output + if: failure() + run: | + if [ -f /tmp/kafka_test_output.txt ]; then + grep -E '\[queue_|FAILED|ERROR|AssertionError|TIMEOUT|short test summary' /tmp/kafka_test_output.txt | tail -9 | while IFS= read -r line; do + echo "::warning::$line" + done + fi diff --git a/docker-compose.kafka.yml b/docker-compose.kafka.yml new file mode 100644 index 0000000..0080d51 --- /dev/null +++ b/docker-compose.kafka.yml @@ -0,0 +1,48 @@ +# Kafka development environment for running integration tests locally. +# +# Usage: +# docker compose -f docker-compose.kafka.yml up -d +# +# KAFKA_BOOTSTRAP_SERVERS=localhost:9092 \ +# AGENTEXEC_STATE_BACKEND=agentexec.state.kafka \ +# uv run pytest tests/test_kafka_integration.py -v +# +# docker compose -f docker-compose.kafka.yml down +# +# Kafka UI available at http://localhost:8080 + +services: + kafka: + image: apache/kafka:3.9.0 + ports: + - "9092:9092" + environment: + KAFKA_NODE_ID: "1" + KAFKA_PROCESS_ROLES: broker,controller + KAFKA_CONTROLLER_QUORUM_VOTERS: 1@localhost:9093 + KAFKA_CONTROLLER_LISTENER_NAMES: CONTROLLER + KAFKA_LISTENERS: PLAINTEXT://:9092,CONTROLLER://:9093 + KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://localhost:9092 + KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: PLAINTEXT:PLAINTEXT,CONTROLLER:PLAINTEXT + KAFKA_INTER_BROKER_LISTENER_NAME: PLAINTEXT + CLUSTER_ID: "agentexec-dev-cluster-01" + KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: "1" + KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR: "1" + KAFKA_TRANSACTION_STATE_LOG_MIN_ISR: "1" + healthcheck: + test: /opt/kafka/bin/kafka-topics.sh --bootstrap-server localhost:9092 --list + interval: 5s + timeout: 10s + retries: 15 + start_period: 15s + + kafka-ui: + image: provectuslabs/kafka-ui:latest + ports: + - "8080:8080" + environment: + KAFKA_CLUSTERS_0_NAME: agentexec + KAFKA_CLUSTERS_0_BOOTSTRAPSERVERS: kafka:9092 + depends_on: + kafka: + condition: service_healthy diff --git a/examples/queue-fairness/run.py b/examples/queue-fairness/run.py new file mode 100644 index 0000000..edcc8ff --- /dev/null +++ b/examples/queue-fairness/run.py @@ -0,0 +1,187 @@ +"""Queue fairness test. + +Validates that tasks distributed across many partition queues get +roughly equal treatment under the scan-based dequeue strategy. + +Usage: + uv run python examples/queue-fairness/run.py + uv run python examples/queue-fairness/run.py --partitions 100 --tasks-per-partition 5 --workers 8 +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import statistics +import time +from uuid import UUID, uuid4 + +from pydantic import BaseModel + +import agentexec as ax +from agentexec.config import CONF +from agentexec.state import backend + + +class BenchContext(BaseModel): + partition_id: int + task_index: int + queued_at: float + + +async def enqueue_tasks(partitions: int, tasks_per_partition: int) -> int: + """Push tasks across N partitions with M tasks each.""" + total = 0 + for p in range(partitions): + partition_key = f"partition:{p}" + for t in range(tasks_per_partition): + task = ax.Task( + task_name="bench_task", + context={ + "partition_id": p, + "task_index": t, + "queued_at": time.time(), + }, + agent_id=uuid4(), + ) + await backend.queue.push( + CONF.queue_prefix, + task.model_dump_json(), + partition_key=partition_key, + ) + total += 1 + return total + + +async def worker( + worker_id: int, + results: list[dict], + stop_event: asyncio.Event, + work_duration: float, +): + """Simulated worker that pops tasks and records timing.""" + while not stop_event.is_set(): + data = await backend.queue.pop(CONF.queue_prefix, timeout=1) + if data is None: + # Check if we should stop + await asyncio.sleep(0.1) + continue + + picked_up_at = time.time() + context = data.get("context", {}) + queued_at = context.get("queued_at", picked_up_at) + wait_time = picked_up_at - queued_at + + results.append({ + "worker_id": worker_id, + "partition_id": context.get("partition_id"), + "task_index": context.get("task_index"), + "wait_time": wait_time, + "picked_up_at": picked_up_at, + }) + + # Simulate work + await asyncio.sleep(work_duration) + + # Release the partition lock + partition_key = f"partition:{context.get('partition_id')}" + await backend.queue.release_lock(CONF.queue_prefix, partition_key) + + +async def run( + partitions: int, + tasks_per_partition: int, + num_workers: int, + work_duration: float, +): + print(f"Enqueueing {partitions} partitions x {tasks_per_partition} tasks = {partitions * tasks_per_partition} total") + total = await enqueue_tasks(partitions, tasks_per_partition) + print(f"Enqueued {total} tasks") + + results: list[dict] = [] + stop_event = asyncio.Event() + + print(f"Starting {num_workers} workers (simulated work: {work_duration}s)") + start = time.time() + + workers = [ + asyncio.create_task(worker(i, results, stop_event, work_duration)) + for i in range(num_workers) + ] + + # Wait until all tasks are processed + while len(results) < total: + await asyncio.sleep(0.5) + elapsed = time.time() - start + print(f" {len(results)}/{total} tasks processed ({elapsed:.1f}s)", end="\r") + + elapsed = time.time() - start + stop_event.set() + + # Let workers drain + await asyncio.gather(*workers, return_exceptions=True) + + print(f"\n\nCompleted {len(results)} tasks in {elapsed:.1f}s") + print(f"Throughput: {len(results) / elapsed:.1f} tasks/sec") + + # Analyze fairness per partition + partition_times: dict[int, list[float]] = {} + for r in results: + pid = r["partition_id"] + if pid not in partition_times: + partition_times[pid] = [] + partition_times[pid].append(r["wait_time"]) + + avg_per_partition = { + pid: statistics.mean(times) for pid, times in partition_times.items() + } + + all_waits = [r["wait_time"] for r in results] + all_avgs = list(avg_per_partition.values()) + + print(f"\nWait time (seconds from enqueue to pickup):") + print(f" Overall mean: {statistics.mean(all_waits):.3f}s") + print(f" Overall median: {statistics.median(all_waits):.3f}s") + print(f" Overall stdev: {statistics.stdev(all_waits):.3f}s") + print(f" Min: {min(all_waits):.3f}s") + print(f" Max: {max(all_waits):.3f}s") + + print(f"\nFairness across {len(partition_times)} partitions:") + print(f" Mean of partition averages: {statistics.mean(all_avgs):.3f}s") + print(f" Stdev of partition averages: {statistics.stdev(all_avgs):.3f}s") + print(f" Min partition avg: {min(all_avgs):.3f}s") + print(f" Max partition avg: {max(all_avgs):.3f}s") + print(f" Spread (max-min): {max(all_avgs) - min(all_avgs):.3f}s") + + # Worker distribution + worker_counts: dict[int, int] = {} + for r in results: + wid = r["worker_id"] + worker_counts[wid] = worker_counts.get(wid, 0) + 1 + + print(f"\nWorker distribution:") + for wid in sorted(worker_counts): + print(f" Worker {wid}: {worker_counts[wid]} tasks") + + await backend.close() + + +def main(): + parser = argparse.ArgumentParser(description="Queue fairness benchmark") + parser.add_argument("--partitions", type=int, default=500, help="Number of partition queues") + parser.add_argument("--tasks-per-partition", type=int, default=12, help="Tasks per partition") + parser.add_argument("--workers", type=int, default=4, help="Number of concurrent workers") + parser.add_argument("--work-duration", type=float, default=0.5, help="Simulated work time (seconds)") + args = parser.parse_args() + + asyncio.run(run( + partitions=args.partitions, + tasks_per_partition=args.tasks_per_partition, + num_workers=args.workers, + work_duration=args.work_duration, + )) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 42ab646..d92754d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,11 @@ dependencies = [ "croniter>=6.0.0", ] +[project.optional-dependencies] +kafka = [ + "aiokafka>=0.11.0", +] + [project.urls] Homepage = "https://github.com/Agent-CI/agentexec" diff --git a/src/agentexec/activity/__init__.py b/src/agentexec/activity/__init__.py index b47d7ae..f9e6e49 100644 --- a/src/agentexec/activity/__init__.py +++ b/src/agentexec/activity/__init__.py @@ -1,38 +1,98 @@ -from agentexec.activity.models import Activity, ActivityLog, Status +from agentexec.activity.models import Activity, ActivityLog +from agentexec.activity.status import Status from agentexec.activity.schemas import ( ActivityDetailSchema, ActivityListItemSchema, ActivityListSchema, ActivityLogSchema, ) -from agentexec.activity.tracker import ( +from agentexec.activity.handlers import ActivityHandler, PostgresHandler +from agentexec.activity.producer import ( create, update, complete, error, cancel_pending, - list, - detail, - count_active, + generate_agent_id, + normalize_agent_id, ) +handler: ActivityHandler = PostgresHandler() + +import uuid +from typing import Any + +from sqlalchemy.orm import Session + + +async def list( + session: Session | None = None, + page: int = 1, + page_size: int = 50, + metadata_filter: dict[str, Any] | None = None, +) -> ActivityListSchema: + """List activities with pagination.""" + from agentexec.core.db import get_session + + with session or get_session() as db: + query = db.query(Activity) + if metadata_filter: + for key, value in metadata_filter.items(): + query = query.filter(Activity.metadata_[key].as_string() == str(value)) + total = query.count() + + rows = Activity.get_list(db, page=page, page_size=page_size, metadata_filter=metadata_filter) + return ActivityListSchema( + items=[ActivityListItemSchema.model_validate(row) for row in rows], + total=total, + page=page, + page_size=page_size, + ) + + +async def detail( + session: Session | None = None, + agent_id: str | uuid.UUID | None = None, + metadata_filter: dict[str, Any] | None = None, +) -> ActivityDetailSchema | None: + """Get a single activity by agent_id.""" + from agentexec.core.db import get_session + + if agent_id is None: + return None + if isinstance(agent_id, str): + agent_id = uuid.UUID(agent_id) + + with session or get_session() as db: + item = Activity.get_by_agent_id(db, agent_id, metadata_filter=metadata_filter) + if item is not None: + return ActivityDetailSchema.model_validate(item) + return None + + +async def count_active(session: Session | None = None) -> int: + """Count active (queued or running) agents.""" + from agentexec.core.db import get_session + + with session or get_session() as db: + return Activity.get_active_count(db) + + __all__ = [ - # Models "Activity", "ActivityLog", "Status", - # Schemas "ActivityLogSchema", "ActivityDetailSchema", "ActivityListItemSchema", "ActivityListSchema", - # Lifecycle API "create", "update", "complete", "error", "cancel_pending", - # Query API + "generate_agent_id", + "normalize_agent_id", "list", "detail", "count_active", diff --git a/src/agentexec/activity/events.py b/src/agentexec/activity/events.py new file mode 100644 index 0000000..e041a9e --- /dev/null +++ b/src/agentexec/activity/events.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import uuid +from typing import Any + +from pydantic import BaseModel + + +class ActivityCreated(BaseModel): + agent_id: uuid.UUID + task_name: str + message: str + metadata: dict[str, Any] | None = None + + +class ActivityUpdated(BaseModel): + agent_id: uuid.UUID + message: str + status: str + percentage: int | None = None + + +# Resolve forward references +ActivityCreated.model_rebuild() +ActivityUpdated.model_rebuild() diff --git a/src/agentexec/activity/handlers.py b/src/agentexec/activity/handlers.py new file mode 100644 index 0000000..74ccde4 --- /dev/null +++ b/src/agentexec/activity/handlers.py @@ -0,0 +1,103 @@ +"""Activity event handlers — pluggable persistence for lifecycle events. + +The activity system uses a handler pattern to decouple event emission from +persistence. Every call to ``activity.create()``, ``activity.update()``, etc. +emits a typed event (``ActivityCreated`` or ``ActivityUpdated``) and routes it +through ``activity.handler``, a callable that decides what to do with it. + +Two handlers are provided: + +- ``PostgresHandler`` (default): Writes events directly to Postgres via + SQLAlchemy. Used by API servers and the pool's main process. + +- ``IPCHandler``: Serializes events onto a ``multiprocessing.Queue`` for + the pool to receive and persist. Used by worker processes, which don't + have database access. + +The handler is swapped at process init time. Workers set the IPC handler +during startup; everything else uses the default Postgres handler:: + + # Worker process (set automatically by Pool) + from agentexec.activity.handlers import IPCHandler + activity.handler = IPCHandler(tx_queue) + + # API server or pool process (default, no setup needed) + await activity.update(agent_id, "Processing", percentage=50) + # → writes directly to Postgres + +Custom handlers can be implemented by conforming to the ``ActivityHandler`` +protocol — any callable that accepts ``ActivityCreated | ActivityUpdated``. +""" + +from __future__ import annotations + +import multiprocessing as mp +from typing import Protocol + +from agentexec.activity.events import ActivityCreated, ActivityUpdated +from agentexec.activity.status import Status + + +class ActivityHandler(Protocol): + """Protocol for activity event handlers. + + Any callable that accepts an ``ActivityCreated`` or ``ActivityUpdated`` + event satisfies this protocol. + """ + def __call__(self, event: ActivityCreated | ActivityUpdated) -> None: ... + + +class PostgresHandler: + """Writes activity events directly to Postgres. + + This is the default handler. It creates a short-lived database session + for each event, writes the appropriate records, and commits. + """ + + def __call__(self, event: ActivityCreated | ActivityUpdated) -> None: + match event: + case ActivityCreated(agent_id=agent_id, task_name=task_name, message=message, metadata=metadata): + from agentexec.activity.models import Activity, ActivityLog + from agentexec.core.db import get_session + + with get_session() as db: + record = Activity(agent_id=agent_id, agent_type=task_name, metadata_=metadata) + db.add(record) + db.flush() + db.add(ActivityLog( + activity_id=record.id, + message=message, + status=Status.QUEUED, + percentage=0, + )) + db.commit() + + case ActivityUpdated(agent_id=agent_id, message=message, status=status, percentage=percentage): + from agentexec.activity.models import Activity + from agentexec.core.db import get_session + + with get_session() as db: + Activity.append_log( + session=db, + agent_id=agent_id, + message=message, + status=Status(status), + percentage=percentage, + ) + + +class IPCHandler: + """Sends activity events to the pool via multiprocessing queue. + + Worker processes use this handler so they don't need database access. + Events are picked up by the pool's event loop and written to Postgres + using the default ``PostgresHandler``. + """ + + tx: mp.Queue + + def __init__(self, tx: mp.Queue) -> None: + self.tx = tx + + def __call__(self, event: ActivityCreated | ActivityUpdated) -> None: + self.tx.put_nowait(event) diff --git a/src/agentexec/activity/models.py b/src/agentexec/activity/models.py index 8d05565..e7aa853 100644 --- a/src/agentexec/activity/models.py +++ b/src/agentexec/activity/models.py @@ -1,5 +1,5 @@ from __future__ import annotations -from enum import Enum as PyEnum + import uuid from datetime import UTC, datetime @@ -22,20 +22,11 @@ from sqlalchemy.engine import RowMapping from sqlalchemy.orm import Mapped, Session, aliased, mapped_column, relationship, declared_attr +from agentexec.activity.status import Status from agentexec.config import CONF from agentexec.core.db import Base -class Status(str, PyEnum): - """Agent execution status.""" - - QUEUED = "queued" - RUNNING = "running" - COMPLETE = "complete" - ERROR = "error" - CANCELED = "canceled" - - class Activity(Base): """Tracks background agent execution sessions. diff --git a/src/agentexec/activity/producer.py b/src/agentexec/activity/producer.py new file mode 100644 index 0000000..29faded --- /dev/null +++ b/src/agentexec/activity/producer.py @@ -0,0 +1,179 @@ +"""Activity event producer — the public API for activity lifecycle. + +All activity methods emit typed events routed through ``activity.handler``. +By default, events are written directly to Postgres. In worker processes, +the handler is swapped to send events via IPC to the pool. + +See ``activity.handlers`` for the handler implementations. +""" + +from __future__ import annotations + +import uuid +from typing import Any + +from sqlalchemy.orm import Session + +import agentexec.activity as activity +from agentexec.activity.events import ActivityCreated, ActivityUpdated +from agentexec.activity.status import Status + + +def generate_agent_id() -> uuid.UUID: + """Generate a new UUID4 agent identifier.""" + return uuid.uuid4() + + +def normalize_agent_id(agent_id: str | uuid.UUID) -> uuid.UUID: + """Coerce a string or UUID to a UUID object.""" + if isinstance(agent_id, str): + return uuid.UUID(agent_id) + return agent_id + + + + +async def create( + task_name: str, + message: str = "Agent queued", + agent_id: str | uuid.UUID | None = None, + session: Session | None = None, + metadata: dict[str, Any] | None = None, +) -> uuid.UUID: + """Create a new activity record with an initial "queued" log entry. + + Called during ``ax.enqueue()`` to register the task in the activity + stream before it hits the queue. + + Args: + task_name: The registered task name (e.g. ``"research"``). + message: Initial log message. + agent_id: Optional pre-generated agent ID. Auto-generated if omitted. + session: Unused — kept for backwards compatibility. + metadata: Arbitrary key-value pairs attached to the activity + (e.g. ``{"organization_id": "org-123"}``). + + Returns: + The agent_id (UUID) of the created record. + + Example:: + + agent_id = await activity.create("research", metadata={"org": "acme"}) + """ + agent_id = normalize_agent_id(agent_id) if agent_id else generate_agent_id() + activity.handler(ActivityCreated( + agent_id=agent_id, + task_name=task_name, + message=message, + metadata=metadata, + )) + return agent_id + + +async def update( + agent_id: str | uuid.UUID, + message: str, + percentage: int | None = None, + status: Status | None = None, + session: Session | None = None, +) -> bool: + """Append a log entry to an existing activity record. + + Defaults to ``Status.RUNNING`` if no status is provided. + + Args: + agent_id: The agent to update. + message: Log message describing the current state. + percentage: Optional completion percentage (0-100). + status: Optional status override (default: ``RUNNING``). + session: Unused — kept for backwards compatibility. + + Example:: + + await activity.update(agent_id, "Fetching data", percentage=30) + """ + activity.handler(ActivityUpdated( + agent_id=normalize_agent_id(agent_id), + message=message, + status=(status or Status.RUNNING).value, + percentage=percentage, + )) + return True + + +async def complete( + agent_id: str | uuid.UUID, + message: str = "Agent completed", + percentage: int = 100, + session: Session | None = None, +) -> bool: + """Mark an activity as complete. + + Args: + agent_id: The agent to mark complete. + message: Completion log message. + percentage: Final percentage (default: 100). + session: Unused — kept for backwards compatibility. + + Example:: + + await activity.complete(agent_id) + """ + activity.handler(ActivityUpdated( + agent_id=normalize_agent_id(agent_id), + message=message, + status=Status.COMPLETE.value, + percentage=percentage, + )) + return True + + +async def error( + agent_id: str | uuid.UUID, + message: str = "Agent failed", + percentage: int = 100, + session: Session | None = None, +) -> bool: + """Mark an activity as failed. + + Args: + agent_id: The agent to mark as errored. + message: Error log message. + percentage: Final percentage (default: 100). + session: Unused — kept for backwards compatibility. + + Example:: + + await activity.error(agent_id, "Connection timeout") + """ + activity.handler(ActivityUpdated( + agent_id=normalize_agent_id(agent_id), + message=message, + status=Status.ERROR.value, + percentage=percentage, + )) + return True + + +async def cancel_pending(session: Session | None = None) -> int: + """Cancel all queued and running activities. + + Typically called during pool shutdown to mark in-flight tasks as + canceled. Reads pending IDs from Postgres and emits cancel events. + + Returns: + Number of activities canceled. + """ + from agentexec.activity.models import Activity + from agentexec.core.db import get_session + + with session or get_session() as db: + pending_ids = Activity.get_pending_ids(db) + for agent_id in pending_ids: + activity.handler(ActivityUpdated( + agent_id=agent_id, + message="Canceled due to shutdown", + status=Status.CANCELED.value, + percentage=None, + )) + return len(pending_ids) diff --git a/src/agentexec/activity/schemas.py b/src/agentexec/activity/schemas.py index e326348..144a73f 100644 --- a/src/agentexec/activity/schemas.py +++ b/src/agentexec/activity/schemas.py @@ -2,9 +2,9 @@ from datetime import datetime from typing import Any -from pydantic import BaseModel, ConfigDict, Field, computed_field +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field -from agentexec.activity.models import Status +from agentexec.activity.status import Status class ActivityLogSchema(BaseModel): @@ -22,15 +22,19 @@ class ActivityLogSchema(BaseModel): class ActivityDetailSchema(BaseModel): """Schema for an agent activity record with optional logs.""" - model_config = ConfigDict(from_attributes=True) + model_config = ConfigDict(from_attributes=True, populate_by_name=True) - id: uuid.UUID + id: uuid.UUID | None = None agent_id: uuid.UUID agent_type: str created_at: datetime updated_at: datetime logs: list[ActivityLogSchema] = Field(default_factory=list) - metadata: dict[str, Any] | None = Field(default=None, alias="metadata_", exclude=True) + metadata: dict[str, Any] | None = Field( + default=None, + validation_alias=AliasChoices("metadata_", "metadata"), + exclude=True, + ) class ActivityListItemSchema(BaseModel): diff --git a/src/agentexec/activity/status.py b/src/agentexec/activity/status.py new file mode 100644 index 0000000..f17b522 --- /dev/null +++ b/src/agentexec/activity/status.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class Status(str, Enum): + QUEUED = "queued" + RUNNING = "running" + COMPLETE = "complete" + ERROR = "error" + CANCELED = "canceled" diff --git a/src/agentexec/activity/tracker.py b/src/agentexec/activity/tracker.py deleted file mode 100644 index 12aeff8..0000000 --- a/src/agentexec/activity/tracker.py +++ /dev/null @@ -1,286 +0,0 @@ -import uuid -from typing import Any - -from sqlalchemy.orm import Session - -from agentexec.activity.models import Activity, ActivityLog, Status -from agentexec.activity.schemas import ( - ActivityDetailSchema, - ActivityListItemSchema, - ActivityListSchema, -) -from agentexec.core.db import get_global_session - - -def generate_agent_id() -> uuid.UUID: - """Generate a new UUID for an agent. - - This is the centralized function for generating agent IDs. - Users can override this if they need custom ID generation logic. - - Returns: - A new UUID4 object - """ - return uuid.uuid4() - - -def normalize_agent_id(agent_id: str | uuid.UUID) -> uuid.UUID: - """Normalize agent_id to UUID object. - - Args: - agent_id: Either a string UUID or UUID object - - Returns: - UUID object - - Raises: - ValueError: If string is not a valid UUID - """ - if isinstance(agent_id, str): - return uuid.UUID(agent_id) - return agent_id - - -def create( - task_name: str, - message: str = "Agent queued", - agent_id: str | uuid.UUID | None = None, - session: Session | None = None, - metadata: dict[str, Any] | None = None, -) -> uuid.UUID: - """Create a new agent activity record with initial queued status. - - Args: - task_name: Name/type of the task (e.g., "research", "analysis") - message: Initial log message (default: "Agent queued") - agent_id: Optional custom agent ID (string or UUID). If not provided, one will be auto-generated. - session: Optional SQLAlchemy session. If not provided, uses global session factory. - metadata: Optional dict of arbitrary metadata to attach to the activity. - Useful for multi-tenancy (e.g., {"organization_id": "org-123"}). - - Returns: - The agent_id (as UUID object) of the created record - """ - agent_id = normalize_agent_id(agent_id) if agent_id else generate_agent_id() - db = session or get_global_session() - - activity_record = Activity( - agent_id=agent_id, - agent_type=task_name, - metadata_=metadata, - ) - db.add(activity_record) - db.flush() - - log = ActivityLog( - activity_id=activity_record.id, - message=message, - status=Status.QUEUED, - percentage=0, - ) - db.add(log) - db.commit() - - return agent_id - - -def update( - agent_id: str | uuid.UUID, - message: str, - percentage: int | None = None, - status: Status | None = None, - session: Session | None = None, -) -> bool: - """Update an agent's activity by adding a new log message. - - This function will set the status to RUNNING unless a different status is explicitly provided. - - Args: - agent_id: The agent_id of the agent to update - message: Log message to append - percentage: Optional completion percentage (0-100) - status: Optional status to set (default: RUNNING) - session: Optional SQLAlchemy session. If not provided, uses global session factory. - - Returns: - True if successful - - Raises: - ValueError: If agent_id not found - """ - db = session or get_global_session() - - Activity.append_log( - session=db, - agent_id=normalize_agent_id(agent_id), - message=message, - status=status if status else Status.RUNNING, - percentage=percentage, - ) - return True - - -def complete( - agent_id: str | uuid.UUID, - message: str = "Agent completed", - percentage: int = 100, - session: Session | None = None, -) -> bool: - """Mark an agent activity as complete. - - Args: - agent_id: The agent_id of the agent to mark as complete - message: Log message (default: "Agent completed") - percentage: Completion percentage (default: 100) - session: Optional SQLAlchemy session. If not provided, uses global session factory. - - Returns: - True if successful - - Raises: - ValueError: If agent_id not found - """ - db = session or get_global_session() - - Activity.append_log( - session=db, - agent_id=normalize_agent_id(agent_id), - message=message, - status=Status.COMPLETE, - percentage=percentage, - ) - return True - - -def error( - agent_id: str | uuid.UUID, - message: str = "Agent failed", - percentage: int = 100, - session: Session | None = None, -) -> bool: - """Mark an agent activity as failed. - - Args: - agent_id: The agent_id of the agent to mark as failed - message: Log message (default: "Agent failed") - percentage: Completion percentage (default: 100) - session: Optional SQLAlchemy session. If not provided, uses ScopedSession. - - Returns: - True if successful - - Raises: - ValueError: If agent_id not found - """ - db = session or get_global_session() - - Activity.append_log( - session=db, - agent_id=normalize_agent_id(agent_id), - message=message, - status=Status.ERROR, - percentage=percentage, - ) - return True - - -def cancel_pending( - session: Session | None = None, -) -> int: - """Mark all queued and running agents as canceled. - - Useful during application shutdown to clean up pending tasks. - - Returns: - Number of agents that were canceled - """ - db = session or get_global_session() - - pending_agent_ids = Activity.get_pending_ids(db) - for agent_id in pending_agent_ids: - Activity.append_log( - session=db, - agent_id=agent_id, - message="Canceled due to shutdown", - status=Status.CANCELED, - percentage=None, - ) - - db.commit() - return len(pending_agent_ids) - - -def list( - session: Session, - page: int = 1, - page_size: int = 50, - metadata_filter: dict[str, Any] | None = None, -) -> ActivityListSchema: - """List activities with pagination. - - Args: - session: SQLAlchemy session to use for the query - page: Page number (1-indexed) - page_size: Number of items per page - metadata_filter: Optional dict of key-value pairs to filter by. - Activities must have metadata containing all specified keys - with exactly matching values. - - Returns: - ActivityList with list of ActivityListItemSchema items - """ - # Build base query for total count - query = session.query(Activity) - if metadata_filter: - for key, value in metadata_filter.items(): - query = query.filter(Activity.metadata_[key].as_string() == str(value)) - total = query.count() - - rows = Activity.get_list( - session, - page=page, - page_size=page_size, - metadata_filter=metadata_filter, - ) - - return ActivityListSchema( - items=[ActivityListItemSchema.model_validate(row) for row in rows], - total=total, - page=page, - page_size=page_size, - ) - - -def detail( - session: Session, - agent_id: str | uuid.UUID, - metadata_filter: dict[str, Any] | None = None, -) -> ActivityDetailSchema | None: - """Get a single activity by agent_id with all logs. - - Args: - session: SQLAlchemy session to use for the query - agent_id: The agent_id to look up - metadata_filter: Optional dict of key-value pairs to filter by. - If provided and the activity's metadata doesn't match, - returns None (same as if not found). - - Returns: - ActivityDetailSchema with full log history, or None if not found - or if metadata doesn't match - """ - if item := Activity.get_by_agent_id(session, agent_id, metadata_filter=metadata_filter): - return ActivityDetailSchema.model_validate(item) - return None - - -def count_active(session: Session) -> int: - """Get count of active (queued or running) agents. - - Args: - session: SQLAlchemy session to use for the query - - Returns: - Count of agents with QUEUED or RUNNING status - """ - return Activity.get_active_count(session) diff --git a/src/agentexec/config.py b/src/agentexec/config.py index 0f9f12a..56092b8 100644 --- a/src/agentexec/config.py +++ b/src/agentexec/config.py @@ -17,10 +17,10 @@ class Config(BaseSettings): description="Prefix for database table names", validation_alias="AGENTEXEC_TABLE_PREFIX", ) - queue_name: str = Field( + queue_prefix: str = Field( default="agentexec_tasks", - description="Name of the Redis list to use as task queue", - validation_alias="AGENTEXEC_QUEUE_NAME", + description="Prefix for task queue keys. Partition queues are {prefix}:{lock_key}.", + validation_alias=AliasChoices("AGENTEXEC_QUEUE_PREFIX", "AGENTEXEC_QUEUE_NAME"), ) num_workers: int = Field( default=4, @@ -72,22 +72,61 @@ class Config(BaseSettings): result_ttl: int = Field( default=3600, - description="TTL in seconds for task results in Redis", + description="TTL in seconds for task results", validation_alias="AGENTEXEC_RESULT_TTL", ) state_backend: str = Field( - default="agentexec.state.redis_backend", - description="State backend to use (fully-qualified module path)", + default="agentexec.state.redis", + description="State backend: 'agentexec.state.redis' or 'agentexec.state.kafka'", validation_alias="AGENTEXEC_STATE_BACKEND", ) + kafka_bootstrap_servers: str | None = Field( + default=None, + description="Kafka bootstrap servers (e.g. 'localhost:9092')", + validation_alias=AliasChoices( + "AGENTEXEC_KAFKA_BOOTSTRAP_SERVERS", "KAFKA_BOOTSTRAP_SERVERS" + ), + ) + kafka_default_partitions: int = Field( + default=6, + description="Default number of partitions for auto-created topics", + validation_alias="AGENTEXEC_KAFKA_DEFAULT_PARTITIONS", + ) + kafka_replication_factor: int = Field( + default=1, + description="Replication factor for auto-created topics", + validation_alias="AGENTEXEC_KAFKA_REPLICATION_FACTOR", + ) + kafka_max_batch_size: int = Field( + default=16384, + description="Producer max batch size in bytes", + validation_alias="AGENTEXEC_KAFKA_MAX_BATCH_SIZE", + ) + kafka_linger_ms: int = Field( + default=5, + description="Producer linger time in milliseconds", + validation_alias="AGENTEXEC_KAFKA_LINGER_MS", + ) + kafka_retention_ms: int = Field( + default=-1, + description="Retention for compacted topics in ms (-1 = forever)", + validation_alias="AGENTEXEC_KAFKA_RETENTION_MS", + ) + key_prefix: str = Field( default="agentexec", description="Prefix for state backend keys", validation_alias="AGENTEXEC_KEY_PREFIX", ) + scheduler_poll_interval: int = Field( + default=10, + description="Seconds between schedule polls", + validation_alias="AGENTEXEC_SCHEDULER_POLL_INTERVAL", + ) + scheduler_timezone: str = Field( default="UTC", description=( @@ -96,14 +135,24 @@ class Config(BaseSettings): ), validation_alias="AGENTEXEC_SCHEDULER_TIMEZONE", ) + max_task_retries: int = Field( + default=3, + description=( + "Maximum number of times a failed task will be retried before " + "being marked as a permanent error. Set to 0 to disable retries. " + "With the Kafka backend, retries preserve partition ordering — " + "the task stays in its original position in the queue." + ), + validation_alias="AGENTEXEC_MAX_TASK_RETRIES", + ) + lock_ttl: int = Field( default=1800, description=( - "TTL in seconds for task lock keys in Redis. " - "This is a safety net for worker process death (OOM, SIGKILL) — " + "TTL in seconds for task lock keys (Redis backend only). " + "Safety net for worker process death (OOM, SIGKILL) — " "locks are always explicitly released on task completion or error. " - "Set this higher than your longest expected task duration to avoid " - "premature lock expiry while a task is still running." + "Ignored by the Kafka backend (partition assignment handles isolation)." ), validation_alias="AGENTEXEC_LOCK_TTL", ) diff --git a/src/agentexec/core/db.py b/src/agentexec/core/db.py index e5f1a00..94609a9 100644 --- a/src/agentexec/core/db.py +++ b/src/agentexec/core/db.py @@ -1,62 +1,37 @@ from sqlalchemy import Engine -from sqlalchemy.orm import DeclarativeBase, Session, scoped_session, sessionmaker +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker __all__ = [ "Base", - "get_global_session", - "set_global_session", - "remove_global_session", + "configure_engine", + "get_session", ] class Base(DeclarativeBase): - """Base class for all SQLAlchemy models in agent-runner. - - Example: - # In alembic/env.py - import agentexec as ax - target_metadata = ax.Base.metadata - """ - + """Base class for all SQLAlchemy models.""" pass -# We need one session per worker process with a shared engine across the application. -# SQLAlchemy's scoped_session provides process-local session management out of the box. -_session_factory: scoped_session[Session] = scoped_session(sessionmaker()) - - -def set_global_session(engine: Engine) -> None: - """Configure the global session factory with an engine. - - Called by workers on startup to bind the session to their database. - - Args: - engine: SQLAlchemy engine to bind sessions to. - """ - _session_factory.configure(bind=engine) +_engine: Engine | None = None +_session_factory: sessionmaker[Session] | None = None -def get_global_session() -> Session: - """Get the worker's process-local session. +def configure_engine(engine: Engine) -> None: + """Set the shared engine for the application.""" + global _engine, _session_factory + _engine = engine + _session_factory = sessionmaker(bind=engine) - This is distinct from request-scoped sessions used in API handlers. - Use this for background task execution within workers. - Returns: - A session bound to the configured engine. +def get_session() -> Session: + """Create a new session from the shared engine. - Raises: - RuntimeError: If set_global_session() hasn't been called. + Use with a context manager: + with get_session() as db: + db.query(...) """ + if _session_factory is None: + raise RuntimeError("Database engine not configured. Call configure_engine() first.") return _session_factory() - - -def remove_global_session() -> None: - """Close and remove the worker's process-local session. - - Called during worker cleanup to close the session and return - connections to the pool. - """ - _session_factory.remove() diff --git a/src/agentexec/core/logging.py b/src/agentexec/core/logging.py index 9a79565..0df26df 100644 --- a/src/agentexec/core/logging.py +++ b/src/agentexec/core/logging.py @@ -1,9 +1,3 @@ -"""Unified logging for main and worker processes. - -Uses multiprocessing's built-in logger which handles cross-process -logging correctly on macOS (spawn mode). -""" - import logging import multiprocessing diff --git a/src/agentexec/core/queue.py b/src/agentexec/core/queue.py index e1dc2cf..d527b6c 100644 --- a/src/agentexec/core/queue.py +++ b/src/agentexec/core/queue.py @@ -1,24 +1,16 @@ -import json from enum import Enum from typing import Any from pydantic import BaseModel -from agentexec import state -from agentexec.config import CONF from agentexec.core.logging import get_logger from agentexec.core.task import Task +from agentexec.state import backend logger = get_logger(__name__) class Priority(str, Enum): - """Task priority levels. - - HIGH: Push to front of queue (processed first). - LOW: Push to back of queue (processed later). - """ - HIGH = "high" LOW = "low" @@ -28,105 +20,40 @@ async def enqueue( context: BaseModel, *, priority: Priority = Priority.LOW, - queue_name: str | None = None, metadata: dict[str, Any] | None = None, ) -> Task: """Enqueue a task for background execution. - Pushes the task to the queue for worker processing. The task must be - registered with a WorkerPool via @pool.task() decorator. + Creates an activity record, serializes the context, and pushes the + task to the queue for workers to process. Args: - task_name: Name of the task to execute. - context: Task context as a Pydantic BaseModel. - priority: Task priority (Priority.HIGH or Priority.LOW). - queue_name: Queue name. Defaults to CONF.queue_name. - metadata: Optional dict of arbitrary metadata to attach to the activity. - Useful for multi-tenancy (e.g., {"organization_id": "org-123"}). + task_name: Name of the registered task (must match a ``@pool.task()``). + context: Pydantic model with the task's input data. + priority: ``Priority.HIGH`` pushes to the front of the queue. + metadata: Optional dict attached to the activity record (e.g. + ``{"organization_id": "org-123"}`` for multi-tenancy). Returns: - Task instance with typed context and agent_id for tracking. - - Example: - @pool.task("research_company") - async def research(agent_id: UUID, context: ResearchContext): - ... + The created Task with its ``agent_id`` for tracking. - task = await ax.enqueue("research_company", ResearchContext(company="Acme")) + Example:: - # With metadata for multi-tenancy - task = await ax.enqueue( - "research_company", - ResearchContext(company="Acme"), - metadata={"organization_id": "org-123"} - ) + task = await ax.enqueue("research", ResearchContext(company="Acme")) + print(task.agent_id) # UUID for tracking """ - push_func = { - Priority.HIGH: state.backend.rpush, - Priority.LOW: state.backend.lpush, - }[priority] - - task = Task.create( + task = await Task.create( task_name=task_name, context=context, metadata=metadata, ) - push_func( - queue_name or CONF.queue_name, + + await backend.queue.push( task.model_dump_json(), + high_priority=(priority == Priority.HIGH), ) logger.info(f"Enqueued task {task.task_name} with agent_id {task.agent_id}") return task -def requeue( - task: Task, - *, - queue_name: str | None = None, -) -> int: - """Push a task back to the end of the queue. - - Used when a task's lock cannot be acquired — the task is returned to the - queue so it can be retried after the lock is released. - - Args: - task: Task to requeue. - queue_name: Queue name. Defaults to CONF.queue_name. - - Returns: - Length of the queue after the push. - """ - return state.backend.lpush( - queue_name or CONF.queue_name, - task.model_dump_json(), - ) - - -async def dequeue( - *, - queue_name: str | None = None, - timeout: int = 1, -) -> dict[str, Any] | None: - """Dequeue a task from the queue. - - Blocks for up to timeout seconds waiting for a task. - - Args: - queue_name: Queue name. Defaults to CONF.queue_name. - timeout: Maximum seconds to wait for a task. - - Returns: - Parsed task data if available, None otherwise. - """ - result = await state.backend.brpop( - queue_name or CONF.queue_name, - timeout=timeout, - ) - - if result is None: - return None - - _, task_data = result - data: dict[str, Any] = json.loads(task_data) - return data diff --git a/src/agentexec/core/results.py b/src/agentexec/core/results.py index 204f8ff..2f3c1a2 100644 --- a/src/agentexec/core/results.py +++ b/src/agentexec/core/results.py @@ -6,7 +6,7 @@ from pydantic import BaseModel -from agentexec import state +from agentexec.state import KEY_RESULT, backend if TYPE_CHECKING: from agentexec.core.task import Task @@ -15,26 +15,18 @@ DEFAULT_TIMEOUT: int = 300 # TODO improve this polling approach -async def get_result(task: Task, timeout: int = DEFAULT_TIMEOUT) -> BaseModel: - """Poll for a task result. - - Waits for a task to complete and returns its result. - Uses automatic type reconstruction from serialized class information. - - Args: - task: The Task instance to wait for - timeout: Maximum seconds to wait for result +async def _get_result(agent_id: str) -> BaseModel | None: + key = backend.format_key(*KEY_RESULT, str(agent_id)) + data = await backend.state.get(key) + return backend.deserialize(data) if data else None - Returns: - Deserialized result as BaseModel instance - Raises: - TimeoutError: If result not available within timeout - """ +async def get_result(task: Task, timeout: int = DEFAULT_TIMEOUT) -> BaseModel: + """Poll for a task result.""" start = time.time() while time.time() - start < timeout: - result = await state.aget_result(task.agent_id) + result = await _get_result(task.agent_id) if result is not None: return result await asyncio.sleep(0.5) @@ -43,22 +35,6 @@ async def get_result(task: Task, timeout: int = DEFAULT_TIMEOUT) -> BaseModel: async def gather(*tasks: Task, timeout: int = DEFAULT_TIMEOUT) -> tuple[BaseModel, ...]: - """Wait for multiple tasks and return their results. - - Similar to asyncio.gather, but for background tasks. - - Args: - *tasks: Task instances to wait for - timeout: Maximum seconds to wait for each result - - Returns: - Tuple of deserialized results as BaseModel instances - - Example: - brand = await ax.enqueue("brand_research", ctx) - market = await ax.enqueue("market_research", ctx) - - brand_result, market_result = await ax.gather(brand, market) - """ + """Wait for multiple tasks and return their results.""" results = await asyncio.gather(*[get_result(task, timeout) for task in tasks]) return tuple(results) diff --git a/src/agentexec/core/task.py b/src/agentexec/core/task.py index ae386c3..ab6a88d 100644 --- a/src/agentexec/core/task.py +++ b/src/agentexec/core/task.py @@ -1,13 +1,15 @@ from __future__ import annotations import inspect +from collections.abc import Mapping from typing import Any, Protocol, TypeAlias, TypeVar, cast, get_type_hints from uuid import UUID -from pydantic import BaseModel, ConfigDict, PrivateAttr, field_serializer +from pydantic import BaseModel, ConfigDict -from agentexec import activity, state +from agentexec import activity from agentexec.config import CONF +from agentexec.state import KEY_RESULT, backend TaskResult: TypeAlias = BaseModel @@ -16,8 +18,6 @@ class _SyncTaskHandler(Protocol[ContextT, ResultT]): - """Protocol for sync task handler functions.""" - __name__: str def __call__( @@ -29,8 +29,6 @@ def __call__( class _AsyncTaskHandler(Protocol[ContextT, ResultT]): - """Protocol for async task handler functions.""" - __name__: str async def __call__( @@ -42,37 +40,20 @@ async def __call__( # TODO: Using Any,Any here because of contravariance limitations with function parameters. -# A function accepting MyContext (specific) is not statically assignable to one expecting -# BaseModel (general). Runtime validation in TaskDefinition._infer_context_type catches -# invalid context/return types. Revisit if Python typing evolves to support this pattern. TaskHandler: TypeAlias = _SyncTaskHandler[Any, Any] | _AsyncTaskHandler[Any, Any] class TaskDefinition: """Definition of a task type (created at registration time). - Encapsulates the handler function and its metadata (context class, etc.). + Encapsulates the handler function and its metadata (context class, lock key). One TaskDefinition can spawn many Task instances. - - This object is created once when a task is registered via @pool.task(), - and acts as a factory to reconstruct Task instances from the queue with - properly typed context. - - Example: - @pool.task("research_company") - async def research(agent_id: UUID, context: ResearchContext): - print(context.company_name) - - # TaskDefinition captures ResearchContext from the type hint - # and uses it to deserialize tasks from the queue """ name: str handler: TaskHandler context_type: type[BaseModel] - # Optional: only set if handler returns a BaseModel subclass result_type: type[BaseModel] | None - # Optional: string template evaluated against context for distributed locking lock_key: str | None def __init__( @@ -84,179 +65,111 @@ def __init__( result_type: type[BaseModel] | None = None, lock_key: str | None = None, ) -> None: - """Initialize task definition. - - Args: - name: Task type name - handler: Handler function (sync or async) - context_type: Optional explicit context type (inferred from annotations if not provided). - result_type: Optional explicit result type (inferred from annotations if not provided). - lock_key: Optional string template for distributed locking. Evaluated against - context fields (e.g., "user:{user_id}"). When set, only one task with - the same evaluated lock key can run at a time. - - Raises: - TypeError: If handler doesn't have a typed 'context' parameter with BaseModel subclass - """ self.name = name self.handler = handler self.context_type = context_type or self._infer_context_type(handler) self.result_type = result_type or self._infer_result_type(handler) self.lock_key = lock_key - async def __call__(self, agent_id: UUID, context: BaseModel) -> TaskResult: - """Delegate calls to the handler function.""" - if inspect.iscoroutinefunction(self.handler): - handler = cast(_AsyncTaskHandler, self.handler) - return await handler(agent_id=agent_id, context=context) - else: - handler = cast(_SyncTaskHandler, self.handler) - return handler(agent_id=agent_id, context=context) + def get_lock_key(self, context: Mapping[str, Any]) -> str | None: + """Evaluate the lock key template against context data.""" + return self.lock_key.format(**context) if self.lock_key else None - def _infer_context_type(self, handler: TaskHandler) -> type[BaseModel]: - """Infer context class from handler's type annotations. + def hydrate_context(self, context: Mapping[str, Any]) -> BaseModel: + """Validate raw context data into the registered Pydantic model.""" + return self.context_type.model_validate(context) + + async def execute(self, task: Task) -> TaskResult | None: + """Execute the task handler and manage its lifecycle. + + Handles activity tracking (started/complete/error) and result storage. + """ + context = self.hydrate_context(task.context) - Looks for a 'context' parameter with a Pydantic BaseModel type hint. + await activity.update( + agent_id=task.agent_id, + message=CONF.activity_message_started, + percentage=0, + ) - Args: - handler: The task handler function + try: + if inspect.iscoroutinefunction(self.handler): + handler = cast(_AsyncTaskHandler, self.handler) + result = await handler(agent_id=task.agent_id, context=context) + else: + handler = cast(_SyncTaskHandler, self.handler) + result = handler(agent_id=task.agent_id, context=context) - Returns: - Context class (BaseModel subclass) + if isinstance(result, BaseModel): + key = backend.format_key(*KEY_RESULT, str(task.agent_id)) + await backend.state.set(key, backend.serialize(result), ttl_seconds=CONF.result_ttl) - Raises: - TypeError: If 'context' parameter is missing or not a BaseModel subclass - """ + await activity.update( + agent_id=task.agent_id, + message=CONF.activity_message_complete, + percentage=100, + status=activity.Status.COMPLETE, + ) + return result + except Exception as e: + await activity.update( + agent_id=task.agent_id, + message=CONF.activity_message_error.format(error=e), + status=activity.Status.ERROR, + ) + raise + + def _infer_context_type(self, handler: TaskHandler) -> type[BaseModel]: hints = get_type_hints(handler) if "context" not in hints: raise TypeError( f"Task handler '{handler.__name__}' must have a 'context' parameter " f"with a BaseModel type annotation" ) - context_type = hints["context"] if not (inspect.isclass(context_type) and issubclass(context_type, BaseModel)): raise TypeError( f"Task handler '{handler.__name__}' context parameter must be a " f"BaseModel subclass, got {context_type}" ) - return context_type def _infer_result_type(self, handler: TaskHandler) -> type[BaseModel] | None: - """Infer result class from handler's return type annotation. - - Looks for a return annotation with a Pydantic BaseModel type hint. - - Args: - handler: The task handler function - - Returns: - Result class (BaseModel subclass) or None if return type is not BaseModel - """ hints = get_type_hints(handler) if "return" not in hints: return None - return_type = hints["return"] if not (inspect.isclass(return_type) and issubclass(return_type, BaseModel)): return None - return return_type class Task(BaseModel): - """Represents a background task instance. + """A background task instance — pure data, no behavior. - Tasks are serialized to JSON and enqueued to Redis for workers to process. - Each task has a type (matching a registered TaskDefinition), a typed context, - and an agent_id for tracking. + Tasks are serialized to JSON and pushed to the queue. Workers pop them, + look up the TaskDefinition by task_name, and execute via the definition. - The context is stored as its native Pydantic type. Serialization to dict - happens automatically via field_serializer when dumping to JSON. - - After deserialization, call bind() to attach the TaskDefinition, then - execute() to run the task handler. - - Example: - # Create with typed context - ctx = ResearchContext(company_name="Anthropic") - task = Task.create("research", ctx) - task.context.company_name # Typed access! - - # Serialize to JSON for Redis (context becomes dict) - json_str = task.model_dump_json() - - # Worker deserializes and executes - task = Task.from_serialized(task_def, data) - await task.execute() + Context is stored as a raw dict. The TaskDefinition hydrates it into + the registered Pydantic model at execution time. """ model_config = ConfigDict(arbitrary_types_allowed=True) task_name: str - context: BaseModel + context: Mapping[str, Any] agent_id: UUID - _definition: TaskDefinition | None = PrivateAttr(default=None) - - @field_serializer("context") - def serialize_context(self, value: BaseModel) -> dict[str, Any]: - """Serialize context to dict for JSON storage.""" - return value.model_dump(mode="json") + retry_count: int = 0 @classmethod - def from_serialized(cls, definition: TaskDefinition, data: dict[str, Any]) -> Task: - """Create a Task from serialized data with its definition bound. - - Args: - definition: The TaskDefinition containing the handler and context_type - data: Serialized task data with task_name, context, and agent_id - - Returns: - Task instance with typed context and bound definition - """ - task = cls( - task_name=data["task_name"], - context=definition.context_type.model_validate(data["context"]), - agent_id=data["agent_id"], - ) - task._definition = definition - return task - - @classmethod - def create( + async def create( cls, task_name: str, context: BaseModel, metadata: dict[str, Any] | None = None, ) -> Task: - """Create a new task with automatic activity tracking. - - This is a convenience method that creates both a Task instance and - its corresponding activity record in one step. - - Args: - task_name: Name/type of the task (e.g., "research", "analysis") - context: Task context as a Pydantic model - metadata: Optional dict of arbitrary metadata to attach to the activity. - Useful for multi-tenancy (e.g., {"organization_id": "org-123"}). - - Returns: - Task instance with agent_id set - - Example: - ctx = ResearchContext(company="Acme") - task = Task.create("research_company", ctx) - task.context.company # Typed access - - # With metadata for multi-tenancy - task = Task.create( - "research_company", - ctx, - metadata={"organization_id": "org-123"} - ) - """ - agent_id = activity.create( + """Create a new task with automatic activity tracking.""" + agent_id = await activity.create( task_name=task_name, message=CONF.activity_message_create, metadata=metadata, @@ -264,73 +177,6 @@ def create( return cls( task_name=task_name, - context=context, + context=context.model_dump(mode="json"), agent_id=agent_id, ) - - def get_lock_key(self) -> str | None: - """Evaluate the lock key template against the task context. - - Returns: - Evaluated lock key string, or None if no lock_key is configured. - - Raises: - RuntimeError: If task has not been bound to a definition. - KeyError: If the template references a field not present in the context. - """ - if self._definition is None: - raise RuntimeError("Task must be bound to a definition before getting lock key") - - if self._definition.lock_key is None: - return None - - return self._definition.lock_key.format(**self.context.model_dump()) - - async def execute(self) -> TaskResult | None: - """Execute the task using its bound definition's handler. - - Manages task lifecycle: marks started, runs handler, marks completed/errored. - - Returns: - Handler return value, or None if handler raised an exception - - Raises: - RuntimeError: If task has not been bound to a definition - """ - if self._definition is None: - raise RuntimeError("Task must be bound to a definition before execution") - - activity.update( - agent_id=self.agent_id, - message=CONF.activity_message_started, - percentage=0, - ) - - try: - result = await self._definition( - agent_id=self.agent_id, - context=self.context, - ) - - # TODO ensure we are properly supporting None return values - if isinstance(result, BaseModel): - await state.aset_result( - self.agent_id, - result, - ttl_seconds=CONF.result_ttl, - ) - - activity.update( - agent_id=self.agent_id, - message=CONF.activity_message_complete, - percentage=100, - status=activity.Status.COMPLETE, - ) - return result - except Exception as e: - activity.update( - agent_id=self.agent_id, - message=CONF.activity_message_error.format(error=e), - status=activity.Status.ERROR, - ) - return None diff --git a/src/agentexec/pipeline.py b/src/agentexec/pipeline.py index 9801e79..8491d80 100644 --- a/src/agentexec/pipeline.py +++ b/src/agentexec/pipeline.py @@ -374,7 +374,7 @@ async def _run_task( _context: StepResult = context for i, step in enumerate(steps): - activity.update( + await activity.update( agent_id, f"Started {step.description}", percentage=int((i / total_steps) * 100), diff --git a/src/agentexec/runners/base.py b/src/agentexec/runners/base.py index d7881ad..d0b1278 100644 --- a/src/agentexec/runners/base.py +++ b/src/agentexec/runners/base.py @@ -117,7 +117,7 @@ def report_status(self) -> Any: agent_id = self._agent_id assert agent_id, "agent_id must be set to use report_status tool" - def report_activity(message: str, percentage: int) -> str: + async def report_activity(message: str, percentage: int) -> str: """Report progress and status updates. Use this tool to report your progress as you work through the task. @@ -129,7 +129,7 @@ def report_activity(message: str, percentage: int) -> str: Returns: Confirmation message """ - activity.update( + await activity.update( agent_id=agent_id, message=message, percentage=percentage, diff --git a/src/agentexec/schedule.py b/src/agentexec/schedule.py index 8dc5b1d..5fa280a 100644 --- a/src/agentexec/schedule.py +++ b/src/agentexec/schedule.py @@ -4,30 +4,23 @@ from datetime import datetime from typing import Any from croniter import croniter -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel, Field -from agentexec import state from agentexec.config import CONF from agentexec.core.logging import get_logger -from agentexec.core.queue import enqueue +from agentexec.state import backend logger = get_logger(__name__) __all__ = [ "register", - "tick", ] REPEAT_FOREVER: int = -1 class ScheduledTask(BaseModel): - """A task scheduled to run on a recurring interval. - - Stored in Redis with a sorted-set index for efficient due-time polling. - Each time it fires, a fresh Task (with its own agent_id) is enqueued - for the worker pool. Stays in Redis until its repeat budget is exhausted. - """ + """A task scheduled to run on a recurring interval.""" task_name: str context: bytes @@ -37,18 +30,18 @@ class ScheduledTask(BaseModel): created_at: float = Field(default_factory=lambda: time.time()) metadata: dict[str, Any] | None = None + @property + def key(self) -> str: + """Unique identity: task_name + cron + context hash.""" + import hashlib + context_hash = hashlib.md5(self.context).hexdigest()[:8] + return f"{self.task_name}:{self.cron}:{context_hash}" + def model_post_init(self, __context: Any) -> None: - """Compute next_run from cron if not explicitly set.""" if self.next_run == 0: self.next_run = self._next_after(self.created_at) def advance(self) -> None: - """Advance next_run to the next future cron occurrence. - - Skips past any missed intervals so we don't enqueue a burst of - catch-up tasks after downtime. Decrements repeat for each skipped - interval (finite schedules only; -1 stays unchanged). - """ now = time.time() while True: self.next_run = self._next_after(self.next_run) @@ -58,22 +51,11 @@ def advance(self) -> None: break def _next_after(self, anchor: float) -> float: - """Compute the next cron occurrence after anchor.""" dt = datetime.fromtimestamp(anchor, tz=CONF.scheduler_tz) return float(croniter(self.cron, dt).get_next(float)) -def _schedule_key(schedule_id: str) -> str: - """Redis key for a schedule definition.""" - return state.backend.format_key(*state.KEY_SCHEDULE, schedule_id) - - -def _queue_key() -> str: - """Redis sorted-set key that indexes schedules by next_run.""" - return state.backend.format_key(*state.KEY_SCHEDULE_QUEUE) - - -def register( +async def register( task_name: str, every: str, context: BaseModel, @@ -81,64 +63,15 @@ def register( repeat: int = REPEAT_FOREVER, metadata: dict[str, Any] | None = None, ) -> None: - """Register a new scheduled task in Redis. - - The task will first fire at the next cron occurrence from now. - - Args: - task_name: Name of the registered task to enqueue on each tick. - every: Schedule expression (cron syntax: min hour dom mon dow). - context: Pydantic context payload passed to the handler each time. - repeat: How many additional executions after the first. - -1 = forever (default), 0 = one-shot, N = N more times. - metadata: Optional metadata dict (e.g. for multi-tenancy). - """ + """Register a new scheduled task.""" task = ScheduledTask( task_name=task_name, - context=state.backend.serialize(context), + context=backend.serialize(context), cron=every, repeat=repeat, metadata=metadata, ) - - state.backend.set( - _schedule_key(task_name), - task.model_dump_json().encode(), - ) - state.backend.zadd(_queue_key(), {task_name: task.next_run}) + await backend.schedule.register(task) logger.info(f"Scheduled {task_name}") -async def tick() -> None: - """Process all scheduled tasks that are due right now. - - For each due task, enqueues it into the normal task queue. If repeats - remain, advances to the next run time. Otherwise removes the schedule. - """ - for _task_name in await state.backend.zrangebyscore(_queue_key(), 0, time.time()): - task_name = _task_name.decode("utf-8") - - try: - data = state.backend.get(_schedule_key(task_name)) - task = ScheduledTask.model_validate_json(data) - except ValidationError: - logger.warning(f"Failed to load schedule {task_name}, skipping") - continue - - await enqueue( - task.task_name, - context=state.backend.deserialize(task.context), - metadata=task.metadata, - ) - - if task.repeat == 0: - state.backend.zrem(_queue_key(), task_name) - state.backend.delete(_schedule_key(task_name)) - logger.info(f"Schedule for '{task_name}' exhausted") - else: - task.advance() - state.backend.set( - _schedule_key(task_name), - task.model_dump_json().encode(), - ) - state.backend.zadd(_queue_key(), {task_name: task.next_run}) diff --git a/src/agentexec/state/__init__.py b/src/agentexec/state/__init__.py index 1bc797e..e670d5b 100644 --- a/src/agentexec/state/__init__.py +++ b/src/agentexec/state/__init__.py @@ -1,266 +1,38 @@ -# cspell:ignore acheck +"""State management layer. -from typing import cast, AsyncGenerator, Coroutine -import importlib -from uuid import UUID +Initializes the configured backend and exposes it as a public reference. +All state operations go through ``backend.state``, ``backend.queue``, +and ``backend.schedule`` directly. Activity uses Postgres directly. + +Pick one backend via AGENTEXEC_STATE_BACKEND: + - 'agentexec.state.redis_backend' (default) + - 'agentexec.state.kafka_backend' +""" + +from __future__ import annotations -from pydantic import BaseModel +import importlib from agentexec.config import CONF -from agentexec.state.backend import StateBackend +from agentexec.state.base import BaseBackend KEY_RESULT = (CONF.key_prefix, "result") KEY_EVENT = (CONF.key_prefix, "event") -KEY_LOCK = (CONF.key_prefix, "lock") -KEY_SCHEDULE = (CONF.key_prefix, "schedule") -KEY_SCHEDULE_QUEUE = (CONF.key_prefix, "schedule_queue") -CHANNEL_LOGS = (CONF.key_prefix, "logs") - -__all__ = [ - "backend", - "get_result", - "aget_result", - "set_result", - "aset_result", - "delete_result", - "adelete_result", - "publish_log", - "subscribe_logs", - "set_event", - "clear_event", - "check_event", - "acheck_event", - "acquire_lock", - "release_lock", - "clear_keys", -] - - -def _load_backend(module_name: str) -> StateBackend: - module = cast(StateBackend, importlib.import_module(module_name)) - if not isinstance(module, StateBackend): # type: ignore[invalid-argument-type] - raise RuntimeError(f"State backend ({module_name}) does not conform to protocol.") - return module - - -backend: StateBackend = _load_backend(CONF.state_backend) - - -def get_result(agent_id: UUID | str) -> BaseModel | None: - """Get result for an agent (sync). - - Returns deserialized BaseModel instance with automatic type reconstruction. - - Args: - agent_id: Unique agent identifier (UUID or string) - - Returns: - Deserialized BaseModel or None if not found - """ - data = backend.get(backend.format_key(*KEY_RESULT, str(agent_id))) - return backend.deserialize(data) if data else None - - -def aget_result(agent_id: UUID | str) -> Coroutine[None, None, BaseModel | None]: - """Get result for an agent (async). - - Returns deserialized BaseModel instance with automatic type reconstruction. - - Args: - agent_id: Unique agent identifier (UUID or string) - - Returns: - Coroutine that resolves to deserialized BaseModel or None if not found - """ - - async def _get() -> BaseModel | None: - data = await backend.aget(backend.format_key(*KEY_RESULT, str(agent_id))) - return backend.deserialize(data) if data else None - - return _get() - - -def set_result( - agent_id: UUID | str, - data: BaseModel, - ttl_seconds: int | None = None, -) -> bool: - """Set result for an agent (sync). - - Args: - agent_id: Unique agent identifier (UUID or string) - data: Result data (must be Pydantic BaseModel) - ttl_seconds: Optional time-to-live in seconds - - Returns: - True if successful - """ - return backend.set( - backend.format_key(*KEY_RESULT, str(agent_id)), - backend.serialize(data), - ttl_seconds=ttl_seconds, - ) - - -def aset_result( - agent_id: UUID | str, - data: BaseModel, - ttl_seconds: int | None = None, -) -> Coroutine[None, None, bool]: - """Set result for an agent (async). - - Args: - agent_id: Unique agent identifier (UUID or string) - data: Result data (must be Pydantic BaseModel) - ttl_seconds: Optional time-to-live in seconds - - Returns: - Coroutine that resolves to True if successful - """ - return backend.aset( - backend.format_key(*KEY_RESULT, str(agent_id)), - backend.serialize(data), - ttl_seconds=ttl_seconds, - ) - - -def delete_result(agent_id: UUID | str) -> int: - """Delete result for an agent (sync). - - Args: - agent_id: Unique agent identifier (UUID or string) - - Returns: - Number of keys deleted (0 or 1) - """ - return backend.delete(backend.format_key(*KEY_RESULT, str(agent_id))) - - -def adelete_result(agent_id: UUID | str) -> Coroutine[None, None, int]: - """Delete result for an agent (async). - - Args: - agent_id: Unique agent identifier (UUID or string) - Returns: - Coroutine that resolves to number of keys deleted (0 or 1) - """ - return backend.adelete(backend.format_key(*KEY_RESULT, str(agent_id))) - - -def publish_log(message: str) -> None: - """Publish a log message to the log channel (sync). - - Args: - message: Log message to publish (should be JSON string) - """ - backend.publish(backend.format_key(*CHANNEL_LOGS), message) - - -def subscribe_logs() -> AsyncGenerator[str, None]: - """Subscribe to log messages (async generator). - - Yields: - Log messages from the channel - """ - return backend.subscribe(backend.format_key(*CHANNEL_LOGS)) - - -def set_event(name: str, id: str) -> bool: - """Set an event flag. - - Args: - name: Event name (e.g., "shutdown", "ready") - id: Event identifier (e.g., pool id) - - Returns: - True if successful - """ - return backend.set(backend.format_key(*KEY_EVENT, name, id), b"1") - - -def clear_event(name: str, id: str) -> int: - """Clear an event flag. - - Args: - name: Event name (e.g., "shutdown", "ready") - id: Event identifier (e.g., pool id) - - Returns: - Number of keys deleted (0 or 1) - """ - return backend.delete(backend.format_key(*KEY_EVENT, name, id)) - - -def check_event(name: str, id: str) -> bool: - """Check if an event flag is set (sync). - Args: - name: Event name (e.g., "shutdown", "ready") - id: Event identifier (e.g., pool id) +def _create_backend(state_backend: str) -> BaseBackend: + """Instantiate the given backend class. - Returns: - True if event is set, False otherwise + The state_backend string is a fully qualified module path containing + a Backend class (e.g. 'agentexec.state.kafka'). """ - return backend.get(backend.format_key(*KEY_EVENT, name, id)) is not None + try: + module = importlib.import_module(state_backend) + return module.Backend() + except ImportError as e: + raise ImportError(f"Could not import backend {state_backend}: {e}") + except AttributeError: + raise ValueError(f"Backend module {state_backend} has no Backend class") -def acheck_event(name: str, id: str) -> Coroutine[None, None, bool]: - """Check if an event flag is set (async). - - Args: - name: Event name (e.g., "shutdown", "ready") - id: Event identifier (e.g., pool id) - - Returns: - Coroutine that resolves to True if event is set, False otherwise - """ - - async def _check() -> bool: - return await backend.aget(backend.format_key(*KEY_EVENT, name, id)) is not None - - return _check() - - -async def acquire_lock(lock_key: str, agent_id: str) -> bool: - """Attempt to acquire a task lock. - - Args: - lock_key: The evaluated lock key (e.g., "user:42") - agent_id: The agent_id holding the lock (for debugging) - - Returns: - True if lock was acquired, False if already held - """ - return await backend.acquire_lock( - backend.format_key(*KEY_LOCK, lock_key), - agent_id, - CONF.lock_ttl, - ) - - -async def release_lock(lock_key: str) -> int: - """Release a task lock. - - Args: - lock_key: The evaluated lock key (e.g., "user:42") - - Returns: - Number of keys deleted (0 or 1) - """ - return await backend.release_lock( - backend.format_key(*KEY_LOCK, lock_key), - ) - - -def clear_keys() -> int: - """Clear all state keys managed by this application. - - Removes all keys matching the configured prefix and the task queue. - This is useful during shutdown to prevent stale tasks from being - picked up on restart. - - Returns: - Total number of keys deleted - """ - return backend.clear_keys() +backend: BaseBackend = _create_backend(CONF.state_backend) diff --git a/src/agentexec/state/backend.py b/src/agentexec/state/backend.py deleted file mode 100644 index 34eb58c..0000000 --- a/src/agentexec/state/backend.py +++ /dev/null @@ -1,363 +0,0 @@ -from types import ModuleType -from typing import AsyncGenerator, Coroutine, Optional, Protocol, runtime_checkable - -from pydantic import BaseModel - - -@runtime_checkable -class StateBackend(Protocol): - """Protocol defining the state backend interface. - - This protocol defines all the operations needed for: - - Task queue management (priority queue operations) - - Result storage (with TTL support) - - Event coordination (shutdown flags, etc.) - - Pub/sub messaging (worker logging) - - Any module that implements these functions can serve as a state backend. - Methods are defined as @staticmethod to match module-level functions. - - Connection management is handled internally - connections are established - lazily when first accessed. Only cleanup needs to be explicit. - """ - - # Connection management - @staticmethod - async def close() -> None: - """Close all connections to the backend. - - This should close both async and sync connections and clean up - any resources. - """ - ... - - # Queue operations (Redis list commands) - @staticmethod - def rpush(key: str, value: str) -> int: - """Push value to the right (front) of the list - for high priority tasks. - - Args: - key: Redis list key - value: Serialized task data - - Returns: - Length of the list after the push - """ - ... - - @staticmethod - def lpush(key: str, value: str) -> int: - """Push value to the left (back) of the list - for low priority tasks. - - Args: - key: Redis list key - value: Serialized task data - - Returns: - Length of the list after the push - """ - ... - - @staticmethod - async def brpop(key: str, timeout: int = 0) -> Optional[tuple[str, str]]: - """Pop value from the right of the list with blocking. - - Args: - key: Redis list key - timeout: Timeout in seconds (0 = block forever) - - Returns: - Tuple of (key, value) or None if timeout - """ - ... - - # Key-value operations - @staticmethod - def aget(key: str) -> Coroutine[None, None, Optional[bytes]]: - """Get value for key asynchronously. - - Args: - key: Key to retrieve - - Returns: - Coroutine that resolves to value as bytes or None if not found - """ - ... - - @staticmethod - def get(key: str) -> Optional[bytes]: - """Get value for key synchronously. - - Args: - key: Key to retrieve - - Returns: - Value as bytes or None if not found - """ - ... - - @staticmethod - def aset( - key: str, value: bytes, ttl_seconds: Optional[int] = None - ) -> Coroutine[None, None, bool]: - """Set value for key asynchronously with optional TTL. - - Args: - key: Key to set - value: Value as bytes - ttl_seconds: Optional time-to-live in seconds - - Returns: - Coroutine that resolves to True if successful - """ - ... - - @staticmethod - def set(key: str, value: bytes, ttl_seconds: Optional[int] = None) -> bool: - """Set value for key synchronously with optional TTL. - - Args: - key: Key to set - value: Value as bytes - ttl_seconds: Optional time-to-live in seconds - - Returns: - True if successful - """ - ... - - @staticmethod - def adelete(key: str) -> Coroutine[None, None, int]: - """Delete key asynchronously. - - Args: - key: Key to delete - - Returns: - Coroutine that resolves to number of keys deleted (0 or 1) - """ - ... - - @staticmethod - def delete(key: str) -> int: - """Delete key synchronously. - - Args: - key: Key to delete - - Returns: - Number of keys deleted (0 or 1) - """ - ... - - # Counter operations - @staticmethod - def incr(key: str) -> int: - """Increment a counter atomically. - - Args: - key: Counter key - - Returns: - Value after increment - """ - ... - - @staticmethod - def decr(key: str) -> int: - """Decrement a counter atomically. - - Args: - key: Counter key - - Returns: - Value after decrement - """ - ... - - # Pub/sub operations - @staticmethod - def publish(channel: str, message: str) -> None: - """Publish message to a channel. - - Args: - channel: Channel name - message: Message to publish - """ - ... - - @staticmethod - def subscribe(channel: str) -> AsyncGenerator[str, None]: - """Subscribe to a channel and yield messages. - - Args: - channel: Channel name - - Yields: - Messages from the channel - """ - ... - - # Key formatting - @staticmethod - def format_key(*args: str) -> str: - """Format a key by joining parts in a backend-specific way. - - Args: - *args: Parts of the key to join - - Returns: - Formatted key string - """ - ... - - # Serialization - @staticmethod - def serialize(obj: BaseModel) -> bytes: - """Serialize a Pydantic BaseModel to bytes. - - Stores the fully qualified class name alongside the data to enable - automatic type reconstruction during deserialization. - - Args: - obj: Pydantic BaseModel instance to serialize - - Returns: - Serialized bytes - - Raises: - TypeError: If obj is not a BaseModel instance - """ - ... - - @staticmethod - def deserialize(data: bytes) -> BaseModel: - """Deserialize bytes back to a Pydantic BaseModel instance. - - Uses the stored class information to dynamically import and reconstruct - the original type. - - Args: - data: Serialized bytes - - Returns: - Deserialized BaseModel instance - - Raises: - ImportError: If the class module cannot be imported - AttributeError: If the class does not exist in the module - ValueError: If the data is invalid - """ - ... - - # Lock operations - @staticmethod - async def acquire_lock(key: str, value: str, ttl_seconds: int) -> bool: - """Attempt to acquire a distributed lock. - - Uses atomic set-if-not-exists with TTL. The TTL is a safety net - for process death — locks should always be explicitly released - via release_lock() on task completion or error. - - Args: - key: Lock key - value: Lock value (typically agent_id for debugging) - ttl_seconds: Lock expiry in seconds (safety net for dead processes) - - Returns: - True if lock was acquired, False if already held - """ - ... - - @staticmethod - async def release_lock(key: str) -> int: - """Release a distributed lock. - - Args: - key: Lock key to release - - Returns: - Number of keys deleted (0 or 1) - """ - ... - - # Sorted set operations - @staticmethod - def zadd(key: str, mapping: dict[str, float]) -> int: - """Add members to a sorted set with scores. - - Args: - key: Sorted set key - mapping: Dict of {member: score} - - Returns: - Number of new members added - """ - ... - - @staticmethod - async def zrangebyscore( - key: str, min_score: float, max_score: float - ) -> list[bytes]: - """Get members with scores between min and max. - - Args: - key: Sorted set key - min_score: Minimum score (inclusive) - max_score: Maximum score (inclusive) - - Returns: - List of members as bytes - """ - ... - - @staticmethod - def zrem(key: str, *members: str) -> int: - """Remove members from a sorted set. - - Args: - key: Sorted set key - *members: Members to remove - - Returns: - Number of members removed - """ - ... - - # Cleanup operations - @staticmethod - def clear_keys() -> int: - """Clear all keys managed by this application. - - Only deletes keys that match the configured prefix and queue name. - This is useful during shutdown to prevent stale tasks from being - picked up on restart. - - Returns: - Total number of keys deleted - """ - ... - - -def load_backend(module: ModuleType) -> StateBackend: - """Load and validate a backend module conforms to StateBackend protocol. - - Uses the Protocol's __protocol_attrs__ to determine required methods. - - Args: - module: Backend module to validate - - Returns: - The module typed as StateBackend - - Raises: - TypeError: If the module is missing required functions - """ - required: frozenset[str] = getattr(StateBackend, "__protocol_attrs__") - missing = [name for name in required if not hasattr(module, name)] - if missing: - raise TypeError( - f"Backend module '{module.__name__}' missing required functions: {missing}" - ) - - return module # type: ignore[return-value] diff --git a/src/agentexec/state/base.py b/src/agentexec/state/base.py new file mode 100644 index 0000000..4330a61 --- /dev/null +++ b/src/agentexec/state/base.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import importlib +import json +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional, TypedDict +from pydantic import BaseModel + +if TYPE_CHECKING: + from agentexec.schedule import ScheduledTask + + +class _SerializeWrapper(TypedDict): + __type__: str + data: dict[str, Any] + + +class BaseBackend(ABC): + """Top-level backend interface with namespaced sub-backends.""" + + state: BaseStateBackend + queue: BaseQueueBackend + schedule: BaseScheduleBackend + + @abstractmethod + def format_key(self, *args: str) -> str: ... + + @abstractmethod + async def close(self) -> None: ... + + def serialize(self, obj: BaseModel) -> bytes: + """Serialize a Pydantic model to bytes with type information.""" + wrapper: _SerializeWrapper = { + "__type__": f"{type(obj).__module__}.{type(obj).__qualname__}", + "data": obj.model_dump(mode="json"), + } + return json.dumps(wrapper).encode("utf-8") + + def deserialize(self, data: bytes) -> BaseModel: + """Deserialize bytes back to a typed Pydantic model.""" + wrapper: _SerializeWrapper = json.loads(data.decode("utf-8")) + module_path, class_name = wrapper["__type__"].rsplit(".", 1) + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + return cls.model_validate(wrapper["data"]) + + +class BaseStateBackend(ABC): + """KV store, counters, locks, pub/sub, sorted index.""" + + @abstractmethod + async def get(self, key: str) -> Optional[bytes]: ... + + @abstractmethod + async def set(self, key: str, value: bytes, ttl_seconds: Optional[int] = None) -> bool: ... + + @abstractmethod + async def delete(self, key: str) -> int: ... + + @abstractmethod + async def counter_incr(self, key: str) -> int: ... + + @abstractmethod + async def counter_decr(self, key: str) -> int: ... + + +class BaseQueueBackend(ABC): + """Task queue with push/pop semantics and partition-level locking.""" + + @abstractmethod + async def push( + self, + value: str, + *, + high_priority: bool = False, + partition_key: str | None = None, + ) -> None: ... + + @abstractmethod + async def pop(self, *, timeout: int = 1) -> dict[str, Any] | None: ... + + @abstractmethod + async def complete(self, partition_key: str | None) -> None: + """Signal that the current task for this partition is done.""" + ... + + + +class BaseScheduleBackend(ABC): + """Schedule storage and retrieval.""" + + @abstractmethod + async def register(self, task: ScheduledTask) -> None: + """Store a scheduled task definition.""" + ... + + @abstractmethod + async def get_due(self) -> list[ScheduledTask]: + """Return all scheduled tasks that are due to fire.""" + ... + + @abstractmethod + async def remove(self, key: str) -> None: + """Remove a schedule by its key.""" + ... diff --git a/src/agentexec/state/kafka.py b/src/agentexec/state/kafka.py new file mode 100644 index 0000000..5541d66 --- /dev/null +++ b/src/agentexec/state/kafka.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import asyncio +import json +import os +import socket +import time +from typing import Any, Optional + +from aiokafka import AIOKafkaConsumer, AIOKafkaProducer, TopicPartition +from aiokafka.admin import AIOKafkaAdminClient, NewTopic + +from agentexec.config import CONF +from agentexec.state.base import BaseBackend, BaseQueueBackend, BaseScheduleBackend, BaseStateBackend + + + +class Backend(BaseBackend): + """Kafka implementation of the agentexec backend.""" + + def __init__(self) -> None: + self._producer: AIOKafkaProducer | None = None + self._consumers: dict[str, AIOKafkaConsumer] = {} + self._admin: AIOKafkaAdminClient | None = None + + self._initialized_topics: set[str] = set() + + # Sub-backends + self.state = KafkaStateBackend(self) + self.queue = KafkaQueueBackend(self) + self.schedule = KafkaScheduleBackend(self) + + def format_key(self, *args: str) -> str: + return ".".join(args) + + async def close(self) -> None: + if self._producer is not None: + await self._producer.stop() + self._producer = None + + for consumer in self._consumers.values(): + await consumer.stop() + self._consumers.clear() + + if self._admin is not None: + await self._admin.close() + self._admin = None + + def _get_bootstrap_servers(self) -> str: + if CONF.kafka_bootstrap_servers is None: + raise ValueError( + "KAFKA_BOOTSTRAP_SERVERS must be configured " + "(e.g. 'localhost:9092' or 'broker1:9092,broker2:9092')" + ) + return CONF.kafka_bootstrap_servers + + def _client_id(self, role: str = "worker") -> str: + return f"{CONF.key_prefix}-{role}-{socket.gethostname()}-{os.getpid()}" + + async def _get_producer(self) -> AIOKafkaProducer: + if self._producer is None: + self._producer = AIOKafkaProducer( + bootstrap_servers=self._get_bootstrap_servers(), + client_id=self._client_id("producer"), + acks="all", + max_batch_size=CONF.kafka_max_batch_size, + linger_ms=CONF.kafka_linger_ms, + ) + await self._producer.start() + return self._producer + + async def _get_admin(self) -> AIOKafkaAdminClient: + if self._admin is None: + self._admin = AIOKafkaAdminClient( + bootstrap_servers=self._get_bootstrap_servers(), + client_id=self._client_id("admin"), + ) + await self._admin.start() + return self._admin + + async def produce( + self, + topic: str, + value: bytes | None, + key: str | bytes | None = None, + headers: dict[str, str] | None = None, + ) -> None: + producer = await self._get_producer() + if isinstance(key, str): + key_bytes = key.encode("utf-8") + else: + key_bytes = key + header_list = [(k, v.encode("utf-8")) for k, v in headers.items()] if headers else None + await producer.send_and_wait(topic, value=value, key=key_bytes, headers=header_list) + + async def ensure_topic(self, topic: str, *, compact: bool = True) -> None: + if topic in self._initialized_topics: + return + + admin = await self._get_admin() + config: dict[str, str] = {} + if compact: + config["cleanup.policy"] = "compact" + config["retention.ms"] = str(CONF.kafka_retention_ms) + + try: + await admin.create_topics( + [ + NewTopic( + name=topic, + num_partitions=CONF.kafka_default_partitions, + replication_factor=CONF.kafka_replication_factor, + topic_configs=config, + ) + ] + ) + except Exception: + pass # Topic already exists + + self._initialized_topics.add(topic) + + async def _get_topic_partitions(self, topic: str) -> list[TopicPartition]: + admin = await self._get_admin() + topics_meta = await admin.describe_topics([topic]) + for t in topics_meta: + if t.get("topic") == topic: + parts = t.get("partitions", []) + if parts: + return [ + TopicPartition(topic, p["partition"]) + for p in sorted(parts, key=lambda p: p["partition"]) + ] + return [TopicPartition(topic, 0)] + + def tasks_topic(self, queue_name: str) -> str: + return f"{CONF.key_prefix}.tasks.{queue_name}" + + def schedule_topic(self) -> str: + return f"{CONF.key_prefix}.schedules" + + +class KafkaStateBackend(BaseStateBackend): + """Kafka state: not supported. + + Kafka is not a key-value store. State operations (get/set, counters) + require a proper KV backend like Redis. Use Kafka for queue and + schedule only. + """ + + async def get(self, key: str) -> Optional[bytes]: + raise NotImplementedError("Kafka backend does not support KV state operations") + + async def set(self, key: str, value: bytes, ttl_seconds: Optional[int] = None) -> bool: + raise NotImplementedError("Kafka backend does not support KV state operations") + + async def delete(self, key: str) -> int: + raise NotImplementedError("Kafka backend does not support KV state operations") + + async def counter_incr(self, key: str) -> int: + raise NotImplementedError("Kafka backend does not support counter operations") + + async def counter_decr(self, key: str) -> int: + raise NotImplementedError("Kafka backend does not support counter operations") + + + +class KafkaQueueBackend(BaseQueueBackend): + """Kafka queue: consumer groups for reliable fan-out.""" + + def __init__(self, backend: Backend) -> None: + self.backend = backend + + async def _get_consumer(self, topic: str) -> AIOKafkaConsumer: + consumers = self.backend._consumers + + if topic not in consumers: + await self.backend.ensure_topic(topic, compact=False) + + consumer = AIOKafkaConsumer( + topic, + bootstrap_servers=self.backend._get_bootstrap_servers(), + group_id=f"{CONF.key_prefix}-workers", + client_id=self.backend._client_id("worker"), + auto_offset_reset="earliest", + enable_auto_commit=False, + ) + await consumer.start() + consumers[topic] = consumer + + return consumers[topic] + + async def push( + self, + queue_name: str, + value: str, + *, + high_priority: bool = False, + partition_key: str | None = None, + ) -> None: + topic = self.backend.tasks_topic(queue_name) + await self.backend.ensure_topic(topic, compact=False) + + # Extract metadata for headers without altering the payload + task_data = json.loads(value) + headers = { + "ax_task_name": task_data.get("task_name", ""), + "ax_agent_id": task_data.get("agent_id", ""), + "ax_retry_count": str(task_data.get("retry_count", 0)), + } + await self.backend.produce(topic, value.encode("utf-8"), key=partition_key, headers=headers) + + async def pop( + self, + queue_name: str, + *, + timeout: int = 1, + ) -> dict[str, Any] | None: + consumer = await self._get_consumer(self.backend.tasks_topic(queue_name)) + + try: + msg = await asyncio.wait_for( + consumer.getone(), + timeout=timeout, + ) + await consumer.commit() + return json.loads(msg.value.decode("utf-8")) + except asyncio.TimeoutError: + return None + + async def complete(self, partition_key: str | None) -> None: + pass # Kafka uses partition assignment, no explicit locks + + +class KafkaScheduleBackend(BaseScheduleBackend): + """Kafka schedule: compacted topic + in-memory cache.""" + + def __init__(self, backend: Backend) -> None: + self.backend = backend + self._consumer: AIOKafkaConsumer | None = None + self._tps: list[TopicPartition] = [] + + async def _ensure_consumer(self) -> AIOKafkaConsumer: + topic = self.backend.schedule_topic() + await self.backend.ensure_topic(topic) + + if self._consumer is None: + self._tps = await self.backend._get_topic_partitions(topic) + self._consumer = AIOKafkaConsumer( + bootstrap_servers=self.backend._get_bootstrap_servers(), + client_id=self.backend._client_id("scheduler"), + enable_auto_commit=False, + ) + await self._consumer.start() + self._consumer.assign(self._tps) + + return self._consumer + + async def register(self, task: ScheduledTask) -> None: + topic = self.backend.schedule_topic() + await self.backend.ensure_topic(topic) + data = task.model_dump_json().encode("utf-8") + headers = { + "ax_task_name": task.task_name, + "ax_cron": task.cron, + "ax_next_run": str(task.next_run), + "ax_repeat": str(task.repeat), + } + await self.backend.produce(topic, data, key=task.key, headers=headers) + + async def get_due(self) -> list[ScheduledTask]: + # TODO: this replays the entire compacted topic on every poll — + # seek, iterate, deserialize, compare for each schedule. Consider + # caching with invalidation or using message timestamps to skip + # schedules that aren't close to due. + from agentexec.schedule import ScheduledTask + from pydantic import ValidationError + + consumer = await self._ensure_consumer() + await consumer.seek_to_beginning(*self._tps) + + now = time.time() + due = [] + records = await consumer.getmany(*self._tps, timeout_ms=1000) + for tp_records in records.values(): + for msg in tp_records: + if msg.value is None: + continue + try: + task = ScheduledTask.model_validate_json(msg.value) + if task.next_run <= now: + due.append(task) + except ValidationError: + continue + + return due + + async def remove(self, key: str) -> None: + topic = self.backend.schedule_topic() + await self.backend.ensure_topic(topic) + await self.backend.produce(topic, None, key=key) diff --git a/src/agentexec/state/redis.py b/src/agentexec/state/redis.py new file mode 100644 index 0000000..244416e --- /dev/null +++ b/src/agentexec/state/redis.py @@ -0,0 +1,264 @@ +"""Redis state backend. + +Provides queue, state (KV/counters/sorted sets), and schedule operations +backed by Redis. The queue implementation uses a partitioned design +inspired by Kafka's consumer groups: + +Queue Key Layout +~~~~~~~~~~~~~~~~ + +All queue keys share a common prefix (``CONF.queue_prefix``, default +``agentexec_tasks``):: + + agentexec_tasks ← default queue (no lock, concurrent) + agentexec_tasks:user:42 ← partition queue for lock scope "user:42" + agentexec_tasks:user:42:lock ← lock for that partition (SET NX EX) + +Tasks without a ``lock_key`` go to the default queue, where any worker can +pop them concurrently. Tasks with a ``lock_key`` (evaluated from the +``TaskDefinition.lock_key`` template against the task context) go to a +partition queue keyed by that value. + +Dequeue Strategy +~~~~~~~~~~~~~~~~ + +Workers call ``queue.pop()`` which uses Redis SCAN to iterate all keys +matching the queue prefix. SCAN returns keys in hash-table order, which +is effectively random — providing fair distribution across partitions +without explicit shuffling. + +For each key discovered: + +1. If it ends with ``:lock``, record it in ``locks_seen`` and skip. +2. If it's a partition queue (not the default), check ``locks_seen`` for + an existing lock. If found, skip. Otherwise attempt ``SET NX EX`` to + acquire the lock. If acquisition fails, skip. +3. ``RPOP`` the queue key. If successful, return the task payload. +4. On task completion, the pool calls ``queue.complete(partition_key)`` + which deletes the lock key, allowing the next task in that partition + to be picked up. + +Redis automatically deletes list keys when they become empty, so drained +partitions disappear from future scans. Lock keys expire via TTL as a +safety net for dead worker recovery. +""" + +from __future__ import annotations + +import uuid +from typing import TYPE_CHECKING, Any, Optional + +import redis +import redis.asyncio + +from agentexec.config import CONF +from agentexec.state.base import BaseBackend, BaseQueueBackend, BaseScheduleBackend, BaseStateBackend + + +class Backend(BaseBackend): + """Redis implementation of the agentexec backend.""" + + _client: redis.asyncio.Redis | None + state: RedisStateBackend + queue: RedisQueueBackend + schedule: RedisScheduleBackend + + def __init__(self) -> None: + self._client = None + self.state = RedisStateBackend(self) + self.queue = RedisQueueBackend(self) + self.schedule = RedisScheduleBackend(self) + + def format_key(self, *args: str) -> str: + return ":".join(args) + + async def close(self) -> None: + if self._client is not None: + await self._client.aclose() + self._client = None + + @property + def client(self) -> redis.asyncio.Redis: + if self._client is None: + if CONF.redis_url is None: + raise ValueError("REDIS_URL must be configured") + self._client = redis.asyncio.Redis.from_url( + CONF.redis_url, + max_connections=CONF.redis_pool_size, + socket_connect_timeout=CONF.redis_pool_timeout, + decode_responses=False, + ) + return self._client + + +class RedisStateBackend(BaseStateBackend): + """Redis state: direct Redis commands.""" + + backend: Backend + + def __init__(self, backend: Backend) -> None: + self.backend = backend + + async def get(self, key: str) -> Optional[bytes]: + return await self.backend.client.get(key) # type: ignore[return-value] + + async def set(self, key: str, value: bytes, ttl_seconds: Optional[int] = None) -> bool: + if ttl_seconds is not None: + return await self.backend.client.set(key, value, ex=ttl_seconds) # type: ignore[return-value] + else: + return await self.backend.client.set(key, value) # type: ignore[return-value] + + async def delete(self, key: str) -> int: + return await self.backend.client.delete(key) # type: ignore[return-value] + + async def counter_incr(self, key: str) -> int: + return await self.backend.client.incr(key) # type: ignore[return-value] + + async def counter_decr(self, key: str) -> int: + return await self.backend.client.decr(key) # type: ignore[return-value] + + +class RedisQueueBackend(BaseQueueBackend): + """Redis queue: partitioned lists with per-group locking. + + Tasks with a partition_key go to {prefix}:{partition_key} and are + serialized by a lock. Tasks without a partition_key go to the + default queue ({prefix}) and execute concurrently. + """ + + backend: Backend + _lock_suffix: bytes = b":lock" + _prefix: str + _default_key: bytes + + def __init__(self, backend: Backend) -> None: + self.backend = backend + self._prefix = CONF.queue_prefix + self._default_key = self._prefix.encode() + + def _queue_key(self, partition_key: str | None = None) -> str: + if partition_key: + return f"{self._prefix}:{partition_key}" + return self._prefix + + def _lock_key(self, queue_key: bytes) -> bytes: + return queue_key + self._lock_suffix + + def _needs_lock(self, queue_key: bytes) -> bool: + return queue_key != self._default_key + + async def _acquire_lock(self, queue_key: bytes) -> bool: + return bool(await self.backend.client.set( + self._lock_key(queue_key), b"1", nx=True, ex=CONF.lock_ttl, + )) + + async def push( + self, + value: str, + *, + high_priority: bool = False, + partition_key: str | None = None, + ) -> None: + """Push a task to the queue. + + Tasks with a ``partition_key`` go to a dedicated partition queue + and are serialized by a lock. Tasks without one go to the default + queue for concurrent processing. + """ + key = self._queue_key(partition_key) + if high_priority: + await self.backend.client.rpush(key, value) + else: + await self.backend.client.lpush(key, value) + + async def pop(self, *, timeout: int = 1) -> dict[str, Any] | None: + """Pop the next eligible task from any queue. + + Scans all queue keys, skips locked partitions, acquires a lock + for the selected partition, and pops the task. Returns ``None`` + if no eligible tasks are available. + """ + import json + + locks_seen: set[bytes] = set() + + # SCAN returns keys in hash-table order (effectively random), + # so we don't need to collect all keys before choosing. + # We try each key eagerly and exit on the first successful pop. + async for key in self.backend.client.scan_iter(match=self._prefix.encode() + b"*", count=100): + if self._needs_lock(key): + if key.endswith(self._lock_suffix): + locks_seen.add(key) + continue # this is a lock record, not executable + + if self._lock_key(key) in locks_seen: + continue # we already observed another worker holds this partition, find another + + if not await self._acquire_lock(key): + continue # another worker holds this partition, find another + + result = await self.backend.client.rpop(key) + if result is None: + if self._needs_lock(key): + # TODO this should never happen; we can improve on the ergonomics of recovery later. + raise RuntimeError(f"Partition queue {key!r} was empty after lock acquired") + + continue # payload was grabbed in a race condition, find another + + return json.loads(result) + + async def complete(self, partition_key: str | None) -> None: + """Signal that the current task for this partition is done. + + Deletes the partition lock so the next task in the same scope + can be picked up. No-op for tasks without a partition key. + """ + if partition_key: + await self.backend.client.delete(self._lock_key(self._queue_key(partition_key).encode())) + + +class RedisScheduleBackend(BaseScheduleBackend): + """Redis schedule: sorted set for time index + hash for payloads. + + Two Redis keys:: + + agentexec:schedules ← sorted set (schedule.key → next_run score) + agentexec:schedules:data ← hash (schedule.key → task JSON) + + ``get_due`` queries the sorted set for keys with score <= now, + then batch-fetches the payloads from the hash. + """ + + backend: Backend + _index_key: str + _data_key: str + + def __init__(self, backend: Backend) -> None: + self.backend = backend + self._index_key = self.backend.format_key(CONF.key_prefix, "schedules") + self._data_key = self.backend.format_key(CONF.key_prefix, "schedules", "data") + + async def register(self, task: ScheduledTask) -> None: + await self.backend.client.hset(self._data_key, task.key, task.model_dump_json().encode()) + await self.backend.client.zadd(self._index_key, {task.key: task.next_run}) + + async def get_due(self) -> list[ScheduledTask]: + import time + from pydantic import ValidationError + from agentexec.schedule import ScheduledTask + + raw = await self.backend.client.zrangebyscore(self._index_key, 0, time.time()) + tasks = [] + for key in raw: + data = await self.backend.client.hget(self._data_key, key) + if data is None: + continue + try: + tasks.append(ScheduledTask.model_validate_json(data)) + except ValidationError: + continue + return tasks + + async def remove(self, key: str) -> None: + await self.backend.client.zrem(self._index_key, key) + await self.backend.client.hdel(self._data_key, key) diff --git a/src/agentexec/state/redis_backend.py b/src/agentexec/state/redis_backend.py deleted file mode 100644 index d7c8dba..0000000 --- a/src/agentexec/state/redis_backend.py +++ /dev/null @@ -1,491 +0,0 @@ -# cspell:ignore rpush lpush brpop RPUSH LPUSH BRPOP -from typing import TypedDict, AsyncGenerator, Coroutine, Optional -import importlib -import json - -import redis -import redis.asyncio -from pydantic import BaseModel - -from agentexec.config import CONF - -__all__ = [ - "format_key", - "serialize", - "deserialize", - "rpush", - "lpush", - "brpop", - "aget", - "get", - "aset", - "set", - "adelete", - "delete", - "incr", - "decr", - "publish", - "subscribe", - "close", - "zadd", - "zrangebyscore", - "zrem", - "clear_keys", -] - -_redis_client: redis.asyncio.Redis | None = None -_redis_sync_client: redis.Redis | None = None -_pubsub: redis.asyncio.client.PubSub | None = None - - -def format_key(*args: str) -> str: - """Format a Redis key by joining parts with colons. - - Args: - *args: Parts of the key - - Returns: - Formatted key string - """ - return ":".join(args) - - -class SerializeWrapper(TypedDict): - __class__: str - __data__: str - - -def serialize(obj: BaseModel) -> bytes: - """Serialize a Pydantic BaseModel to JSON bytes with type information. - - Stores the fully qualified class name alongside the data, similar to pickle. - This allows deserialization without needing an external type registry. - - Args: - obj: Pydantic BaseModel instance to serialize - - Returns: - JSON-encoded bytes containing class info and data - - Raises: - TypeError: If obj is not a BaseModel instance - """ - if not isinstance(obj, BaseModel): - raise TypeError(f"Expected BaseModel, got {type(obj)}") - - cls = type(obj) - wrapper: SerializeWrapper = { - "__class__": f"{cls.__module__}.{cls.__qualname__}", - "__data__": obj.model_dump_json(), - } - - return json.dumps(wrapper).encode("utf-8") - - -def deserialize(data: bytes) -> BaseModel: - """Deserialize JSON bytes back to a Pydantic BaseModel instance. - - Uses the stored class information to dynamically import and reconstruct - the original type, similar to pickle. - - Args: - data: JSON-encoded bytes containing class info and data - - Returns: - Deserialized BaseModel instance - - Raises: - ImportError: If the class module cannot be imported - AttributeError: If the class does not exist in the module - ValueError: If the data is invalid JSON or missing required fields - """ - wrapper: SerializeWrapper = json.loads(data.decode("utf-8")) - class_path = wrapper["__class__"] - json_data = wrapper["__data__"] - - # Import the class dynamically (e.g., "myapp.models.Result" → myapp.models module) - module_path, class_name = class_path.rsplit(".", 1) - module = importlib.import_module(module_path) - cls = getattr(module, class_name) - - result: BaseModel = cls.model_validate_json(json_data) - return result - - -def _get_async_client() -> redis.asyncio.Redis: - """Get async Redis client, initializing lazily if needed. - - Returns: - Async Redis client instance - - Raises: - ValueError: If REDIS_URL is not configured - """ - global _redis_client - - if _redis_client is None: - if CONF.redis_url is None: - raise ValueError("REDIS_URL must be configured") - - _redis_client = redis.asyncio.Redis.from_url( - CONF.redis_url, - max_connections=CONF.redis_pool_size, - socket_connect_timeout=CONF.redis_pool_timeout, - decode_responses=False, # Handle binary data (pickled results) - ) - - return _redis_client - - -def _get_sync_client() -> redis.Redis: - """Get sync Redis client, initializing lazily if needed. - - Returns: - Sync Redis client instance - - Raises: - ValueError: If REDIS_URL is not configured - """ - global _redis_sync_client - - if _redis_sync_client is None: - if CONF.redis_url is None: - raise ValueError("REDIS_URL must be configured") - - _redis_sync_client = redis.Redis.from_url( - CONF.redis_url, - max_connections=CONF.redis_pool_size, - socket_connect_timeout=CONF.redis_pool_timeout, - decode_responses=False, - ) - - return _redis_sync_client - - -async def close() -> None: - """Close all Redis connections and clean up resources.""" - global _redis_client, _redis_sync_client, _pubsub - - # Close pubsub if active - if _pubsub is not None: - await _pubsub.close() - _pubsub = None - - # Close async client - if _redis_client is not None: - await _redis_client.aclose() - _redis_client = None - - # Close sync client - if _redis_sync_client is not None: - _redis_sync_client.close() - _redis_sync_client = None - - -def rpush(key: str, value: str) -> int: - """Push value to the right (front) of the list - for high priority tasks. - - Args: - key: Redis list key - value: Serialized task data - - Returns: - Length of the list after the push - """ - client = _get_sync_client() - return client.rpush(key, value) # type: ignore[return-value] - - -def lpush(key: str, value: str) -> int: - """Push value to the left (back) of the list - for low priority tasks. - - Args: - key: Redis list key - value: Serialized task data - - Returns: - Length of the list after the push - """ - client = _get_sync_client() - return client.lpush(key, value) # type: ignore[return-value] - - -async def brpop(key: str, timeout: int = 0) -> Optional[tuple[str, str]]: - """Pop value from the right of the list with blocking. - - Args: - key: Redis list key - timeout: Timeout in seconds (0 = block forever) - - Returns: - Tuple of (key, value) or None if timeout - """ - client = _get_async_client() - result = await client.brpop([key], timeout=timeout) # type: ignore[misc] - if result is None: - return None - # Redis returns bytes, decode to string - list_key, value = result - return (list_key.decode("utf-8"), value.decode("utf-8")) - - -def aget(key: str) -> Coroutine[None, None, Optional[bytes]]: - """Get value for key asynchronously. - - Args: - key: Key to retrieve - - Returns: - Coroutine that resolves to value as bytes or None if not found - """ - client = _get_async_client() - return client.get(key) # type: ignore[return-value] - - -def get(key: str) -> Optional[bytes]: - """Get value for key synchronously. - - Args: - key: Key to retrieve - - Returns: - Value as bytes or None if not found - """ - client = _get_sync_client() - return client.get(key) # type: ignore[return-value] - - -def aset(key: str, value: bytes, ttl_seconds: Optional[int] = None) -> Coroutine[None, None, bool]: - """Set value for key asynchronously with optional TTL. - - Args: - key: Key to set - value: Value as bytes - ttl_seconds: Optional time-to-live in seconds - - Returns: - Coroutine that resolves to True if successful - """ - client = _get_async_client() - if ttl_seconds is not None: - return client.set(key, value, ex=ttl_seconds) # type: ignore[return-value] - else: - return client.set(key, value) # type: ignore[return-value] - - -def set(key: str, value: bytes, ttl_seconds: Optional[int] = None) -> bool: - """Set value for key synchronously with optional TTL. - - Args: - key: Key to set - value: Value as bytes - ttl_seconds: Optional time-to-live in seconds - - Returns: - True if successful - """ - client = _get_sync_client() - if ttl_seconds is not None: - return client.set(key, value, ex=ttl_seconds) # type: ignore[return-value] - else: - return client.set(key, value) # type: ignore[return-value] - - -def adelete(key: str) -> Coroutine[None, None, int]: - """Delete key asynchronously. - - Args: - key: Key to delete - - Returns: - Coroutine that resolves to number of keys deleted (0 or 1) - """ - client = _get_async_client() - return client.delete(key) # type: ignore[return-value] - - -def delete(key: str) -> int: - """Delete key synchronously. - - Args: - key: Key to delete - - Returns: - Number of keys deleted (0 or 1) - """ - client = _get_sync_client() - return client.delete(key) # type: ignore[return-value] - - -def incr(key: str) -> int: - """Increment a counter atomically. - - Args: - key: Counter key - - Returns: - Value after increment - """ - client = _get_sync_client() - return client.incr(key) # type: ignore[return-value] - - -def decr(key: str) -> int: - """Decrement a counter atomically. - - Args: - key: Counter key - - Returns: - Value after decrement - """ - client = _get_sync_client() - return client.decr(key) # type: ignore[return-value] - - -async def acquire_lock(key: str, value: str, ttl_seconds: int) -> bool: - """Attempt to acquire a distributed lock using SET NX EX. - - Args: - key: Lock key - value: Lock value (typically agent_id for debugging) - ttl_seconds: Lock expiry in seconds (safety net for dead processes) - - Returns: - True if lock was acquired, False if already held - """ - client = _get_async_client() - result = await client.set(key, value, nx=True, ex=ttl_seconds) - return result is not None - - -async def release_lock(key: str) -> int: - """Release a distributed lock. - - Args: - key: Lock key to release - - Returns: - Number of keys deleted (0 or 1) - """ - client = _get_async_client() - return await client.delete(key) # type: ignore[return-value] - - -def publish(channel: str, message: str) -> None: - """Publish message to a channel. - - Args: - channel: Channel name - message: Message to publish - """ - client = _get_sync_client() - client.publish(channel, message) - - -async def subscribe(channel: str) -> AsyncGenerator[str, None]: - """Subscribe to a channel and yield messages. - - Args: - channel: Channel name - - Yields: - Messages from the channel as strings - """ - global _pubsub - - client = _get_async_client() - _pubsub = client.pubsub() - await _pubsub.subscribe(channel) - - try: - async for message in _pubsub.listen(): - if message["type"] == "message": - # Decode bytes to string - data = message["data"] - if isinstance(data, bytes): - yield data.decode("utf-8") - else: - yield data - finally: - await _pubsub.unsubscribe(channel) - await _pubsub.close() - _pubsub = None - - -def zadd(key: str, mapping: dict[str, float]) -> int: - """Add members to a sorted set with scores. - - Args: - key: Sorted set key - mapping: Dict of {member: score} - - Returns: - Number of new members added - """ - client = _get_sync_client() - return client.zadd(key, mapping) # type: ignore[return-value] - - -async def zrangebyscore( - key: str, min_score: float, max_score: float -) -> list[bytes]: - """Get members with scores between min and max. - - Args: - key: Sorted set key - min_score: Minimum score (inclusive) - max_score: Maximum score (inclusive) - - Returns: - List of members as bytes - """ - client = _get_async_client() - return await client.zrangebyscore(key, min_score, max_score) # type: ignore[return-value] - - -def zrem(key: str, *members: str) -> int: - """Remove members from a sorted set. - - Args: - key: Sorted set key - *members: Members to remove - - Returns: - Number of members removed - """ - client = _get_sync_client() - return client.zrem(key, *members) # type: ignore[return-value] - - -def clear_keys() -> int: - """Clear all Redis keys managed by this application. - - Uses SCAN to safely iterate through keys without blocking Redis. - Only deletes keys that match the configured prefix and queue name. - - Returns: - Total number of keys deleted, or 0 if Redis is not configured - """ - if CONF.redis_url is None: - return 0 - - client = _get_sync_client() - deleted = 0 - - # Delete the task queue - deleted += client.delete(CONF.queue_name) - - # Scan and delete all keys matching the configured prefix - # Pattern: "agentexec:*" (or whatever key_prefix is configured) - pattern = f"{CONF.key_prefix}:*" - cursor = 0 - - while True: - cursor, keys = client.scan(cursor=cursor, match=pattern, count=100) - if keys: - deleted += client.delete(*keys) - if cursor == 0: - break - - return deleted diff --git a/src/agentexec/tracker.py b/src/agentexec/tracker.py index 26a4fa2..6de64f5 100644 --- a/src/agentexec/tracker.py +++ b/src/agentexec/tracker.py @@ -5,63 +5,44 @@ Example: tracker = ax.Tracker("research", batch_id) - tracker.incr() # Count the discovery process itself + await tracker.incr() # Count the discovery process itself @function_tool async def queue_research(company: str) -> str: - tracker.incr() + await tracker.incr() await ax.enqueue("research", ResearchContext(company=company, batch_id=batch_id)) return f"Queued {company}" # When discovery finishes, decrement itself - if tracker.decr() == 0: + if await tracker.decr() == 0: await ax.enqueue("aggregate", AggregateContext(batch_id=batch_id)) # In research task - decrement when done tracker = ax.Tracker("research", context.batch_id) # ... do research ... - if tracker.decr() == 0: + if await tracker.decr() == 0: await ax.enqueue("aggregate", AggregateContext(batch_id=context.batch_id)) """ -from agentexec import state from agentexec.config import CONF +from agentexec.state import backend class Tracker: - """Coordinate dynamic fan-out with an atomic counter. - - Args: - *args: Key parts used to construct the tracker's unique key. - Typically includes a name and identifier, e.g., ("research", batch_id) - """ + """Coordinate dynamic fan-out with an atomic counter.""" def __init__(self, *args: str): - self._key = state.backend.format_key(CONF.key_prefix, "tracker", *args) - - def incr(self) -> int: - """Increment the counter. - - Returns: - Counter value after increment. - """ - return state.backend.incr(self._key) + self._key = backend.format_key(CONF.key_prefix, "tracker", *args) - def decr(self) -> int: - """Decrement the counter. + async def incr(self) -> int: + return await backend.state.counter_incr(self._key) - Returns: - Counter value after decrement. - """ - return state.backend.decr(self._key) + async def decr(self) -> int: + return await backend.state.counter_decr(self._key) - @property - def count(self) -> int: - """Get current counter value.""" - result = state.backend.get(self._key) + async def count(self) -> int: + result = await backend.state.get(self._key) return int(result) if result else 0 - @property - def complete(self) -> bool: - """Check if counter has reached zero.""" - return self.count == 0 + async def complete(self) -> bool: + return await self.count() == 0 diff --git a/src/agentexec/worker/event.py b/src/agentexec/worker/event.py index 7eede1e..797549d 100644 --- a/src/agentexec/worker/event.py +++ b/src/agentexec/worker/event.py @@ -1,5 +1,7 @@ from __future__ import annotations -from agentexec import state + +from agentexec.config import CONF +from agentexec.state import KEY_EVENT, backend class StateEvent: @@ -7,42 +9,23 @@ class StateEvent: Provides an interface similar to threading.Event/multiprocessing.Event, but backed by the state backend for cross-process and cross-machine coordination. - - This class is fully picklable (just stores name and optional id) and works - across any process that can connect to the same state backend. - - set() and clear() are synchronous for use from pool management code. - is_set() is async for use from worker event loops. - - Example: - event = StateEvent("shutdown", "pool1") - - # In pool (sync context) - event.set() - - # In worker (async context) - if await event.is_set(): - print("Shutdown signal received") """ def __init__(self, name: str, id: str) -> None: - """Initialize the event. - - Args: - name: Event name (e.g., "shutdown", "ready") - id: Identifier to scope the event (e.g., pool id) - """ self.name = name self.id = id - def set(self) -> None: + def _key(self) -> str: + return backend.format_key(*KEY_EVENT, self.name, self.id) + + async def set(self) -> None: """Set the event flag to True.""" - state.set_event(self.name, self.id) + await backend.state.set(self._key(), b"1") - def clear(self) -> None: + async def clear(self) -> None: """Reset the event flag to False.""" - state.clear_event(self.name, self.id) + await backend.state.delete(self._key()) async def is_set(self) -> bool: """Check if the event flag is True.""" - return await state.acheck_event(self.name, self.id) + return await backend.state.get(self._key()) is not None diff --git a/src/agentexec/worker/logging.py b/src/agentexec/worker/logging.py index acbb34c..3af7d5a 100644 --- a/src/agentexec/worker/logging.py +++ b/src/agentexec/worker/logging.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging +import multiprocessing as mp from pydantic import BaseModel -from agentexec import state LOGGER_NAME = "agentexec" LOG_CHANNEL = "agentexec:logs" @@ -9,7 +9,7 @@ class LogMessage(BaseModel): - """Schema for log messages sent via state backend pubsub.""" + """Schema for log messages sent via the worker message queue.""" name: str levelno: int @@ -22,7 +22,6 @@ class LogMessage(BaseModel): @classmethod def from_log_record(cls, record: logging.LogRecord) -> LogMessage: - """Create a LogMessage from a logging.LogRecord.""" return cls( name=record.name, levelno=record.levelno, @@ -35,7 +34,6 @@ def from_log_record(cls, record: logging.LogRecord) -> LogMessage: ) def to_log_record(self) -> logging.LogRecord: - """Convert back to a logging.LogRecord.""" record = logging.LogRecord( name=self.name, level=self.levelno, @@ -51,21 +49,18 @@ def to_log_record(self) -> logging.LogRecord: return record -class StateLogHandler(logging.Handler): - """Logging handler that publishes log records to state backend pubsub. +class QueueLogHandler(logging.Handler): + """Logging handler that sends log records to the pool via multiprocessing queue.""" - Used by worker processes to send logs to the main process. - """ - - def __init__(self, channel: str = LOG_CHANNEL): + def __init__(self, tx: mp.Queue): super().__init__() - self.channel = channel + self.tx = tx def emit(self, record: logging.LogRecord) -> None: - """Publish log record to log channel.""" try: + from agentexec.worker.pool import LogEntry message = LogMessage.from_log_record(record) - state.publish_log(message.model_dump_json()) + self.tx.put_nowait(LogEntry(record=message)) except Exception: self.handleError(record) @@ -73,29 +68,14 @@ def emit(self, record: logging.LogRecord) -> None: _worker_logging_configured = False -def get_worker_logger(name: str) -> logging.Logger: - """Configure worker logging and return a logger. - - On first call, sets up a state handler that publishes log records - to the main process via state backend pubsub. Subsequent calls just return - a logger under the agentexec namespace. - - Args: - name: Logger name. Typically __name__. - - Returns: - Configured logger instance. - - Example: - logger = get_worker_logger(__name__) - logger.info("Worker starting") - """ +def get_worker_logger(name: str, tx: mp.Queue | None = None) -> logging.Logger: + """Configure worker logging and return a logger.""" global _worker_logging_configured - if not _worker_logging_configured: + if not _worker_logging_configured and tx is not None: root = logging.getLogger(LOGGER_NAME) root.setLevel(logging.INFO) - root.addHandler(StateLogHandler()) + root.addHandler(QueueLogHandler(tx)) root.propagate = False _worker_logging_configured = True diff --git a/src/agentexec/worker/pool.py b/src/agentexec/worker/pool.py index 8d4dedc..6aa495c 100644 --- a/src/agentexec/worker/pool.py +++ b/src/agentexec/worker/pool.py @@ -5,15 +5,21 @@ import multiprocessing as mp from dataclasses import dataclass from typing import Any, Callable -from uuid import uuid4 +from uuid import UUID, uuid4 from pydantic import BaseModel from sqlalchemy import Engine, create_engine +from sqlalchemy.orm import Session, sessionmaker -from agentexec import state from agentexec.config import CONF -from agentexec.core.db import remove_global_session, set_global_session -from agentexec.core.queue import dequeue, requeue +from agentexec.state import backend +import queue as stdlib_queue + +from agentexec import activity +from agentexec.activity.events import ActivityUpdated +from agentexec.activity.handlers import IPCHandler +from agentexec.core.db import configure_engine +from agentexec.core.queue import enqueue from agentexec.core.task import Task, TaskDefinition, TaskHandler from agentexec import schedule from agentexec.worker.event import StateEvent @@ -29,6 +35,24 @@ ] +class Message(BaseModel): + """Base event sent from a worker to the pool.""" + pass + + +class TaskFailed(Message): + task: Task + error: str + + @classmethod + def from_exception(cls, task: Task, exception: Exception) -> TaskFailed: + return cls(task=task, error=str(exception)) + + +class LogEntry(Message): + record: LogMessage + + class _EmptyContext(BaseModel): """Default context for scheduled tasks that don't need one.""" @@ -44,22 +68,21 @@ def _get_pool_id() -> str: class WorkerContext: """Shared context passed from Pool to Worker processes.""" - database_url: str shutdown_event: StateEvent tasks: dict[str, TaskDefinition] - queue_name: str + tx: mp.Queue class Worker: """Individual worker process with isolated state. Each worker configures the scoped Session factory on startup. - Task handlers can use get_global_session() to get the process-local session. + Workers don't have database access — all persistence goes through the pool. """ _worker_id: int _context: WorkerContext - _logger: logging.Logger + logger: logging.Logger def __init__(self, worker_id: int, context: WorkerContext): """Initialize worker with isolated state. @@ -70,7 +93,9 @@ def __init__(self, worker_id: int, context: WorkerContext): """ self._worker_id = worker_id self._context = context - self._logger = get_worker_logger(__name__) + self.logger = get_worker_logger(__name__, tx=context.tx) + + activity.handler = IPCHandler(context.tx) @classmethod def run_in_process(cls, worker_id: int, context: WorkerContext) -> None: @@ -85,68 +110,44 @@ def run_in_process(cls, worker_id: int, context: WorkerContext) -> None: def run(self) -> None: """Main worker entry point - sets up async loop and runs.""" - self._logger.info(f"Worker {self._worker_id} starting") - - engine = create_engine(self._context.database_url) - set_global_session(engine) + self.logger.info(f"Worker {self._worker_id} starting") try: asyncio.run(self._run()) except Exception as e: - self._logger.exception(f"Worker {self._worker_id} fatal error: {e}") + self.logger.exception(f"Worker {self._worker_id} fatal error: {e}") raise - - async def _run(self) -> None: - """Async main loop - polls queue and processes tasks.""" - try: - # No sleep needed - dequeue() uses brpop which blocks waiting for tasks - while not await self._context.shutdown_event.is_set(): - if (task := await self._dequeue_task()) is not None: - lock_key = task.get_lock_key() - - if lock_key is not None: - acquired = await state.acquire_lock(lock_key, str(task.agent_id)) - if not acquired: - self._logger.debug( - f"Worker {self._worker_id} lock held for {task.task_name} " - f"(lock_key={lock_key}), requeuing" - ) - requeue(task, queue_name=self._context.queue_name) - continue - - try: - self._logger.info(f"Worker {self._worker_id} processing: {task.task_name}") - await task.execute() - self._logger.info(f"Worker {self._worker_id} completed: {task.task_name}") - finally: - if lock_key is not None: - await state.release_lock(lock_key) - except Exception as e: - self._logger.exception(f"Worker {self._worker_id} error: {e}") - # Continue processing other tasks - # TODO allow configurable behavior here (retry, backoff, fail) - # TODO all of the actual logic is handled in task.execute(), so I don't know why we ever end up here. finally: - await state.backend.close() - remove_global_session() - self._logger.info(f"Worker {self._worker_id} shutting down") + asyncio.run(backend.close()) + self.logger.info(f"Worker {self._worker_id} shutting down") - async def _dequeue_task(self) -> Task | None: - """Dequeue and hydrate a task from the Redis queue. + def _send(self, message: Message) -> None: + """Send a message to the pool via the multiprocessing queue.""" + self._context.tx.put_nowait(message) - Reconstructs the typed context using the TaskDefinition - and binds the definition to the task. + async def _run(self) -> None: + """Async main loop - dequeue, execute, complete.""" + while not await self._context.shutdown_event.is_set(): + try: + data = await backend.queue.pop(timeout=1) + if data is None: + continue - Returns: - Hydrated Task instance if available, else None. - """ - if (data := await dequeue(queue_name=self._context.queue_name)) is not None: - return Task.from_serialized( - definition=self._context.tasks[data["task_name"]], - data=data, - ) + task = Task.model_validate(data) + definition = self._context.tasks[task.task_name] + partition_key = definition.get_lock_key(task.context) + + try: + self.logger.info(f"Worker {self._worker_id} processing: {task.task_name}") + await definition.execute(task) + self.logger.info(f"Worker {self._worker_id} completed: {task.task_name}") + except Exception as e: + self._send(TaskFailed.from_exception(task, e)) + finally: + await backend.queue.complete(partition_key) + except Exception as e: + self.logger.exception(f"Worker {self._worker_id} error: {e}") - return None class Pool: @@ -177,14 +178,12 @@ def __init__( self, engine: Engine | None = None, database_url: str | None = None, - queue_name: str | None = None, ) -> None: """Initialize the worker pool. Args: engine: SQLAlchemy engine (URL will be extracted for workers). database_url: Database URL string. Alternative to passing engine. - queue_name: Redis queue name. Defaults to CONF.queue_name. Raises: ValueError: If neither engine nor database_url is provided. @@ -194,16 +193,16 @@ def __init__( raise ValueError("Either engine or database_url must be provided") engine = engine or create_engine(database_url) # type: ignore[arg-type] - set_global_session(engine) - + configure_engine(engine) + self._worker_queue: mp.Queue = mp.Queue() self._context = WorkerContext( - database_url=database_url or engine.url.render_as_string(hide_password=False), shutdown_event=StateEvent("shutdown", _get_pool_id()), tasks={}, - queue_name=queue_name or CONF.queue_name, + tx=self._worker_queue, ) self._processes = [] self._log_handler = None + self._pending_schedules: list[dict[str, Any]] = [] def task( self, @@ -345,6 +344,9 @@ def add_schedule( ``pool.add_task()``. The scheduler loop runs automatically inside ``pool.run()`` — no extra setup needed. + Schedules are stored and registered with the backend when + ``start()`` is called. + Args: task_name: Name of a registered task. every: Schedule expression (cron syntax: min hour dom mon dow). @@ -366,58 +368,52 @@ def add_schedule( f"Use @pool.task() or pool.add_task() first." ) - schedule.register( + self._pending_schedules.append(dict( task_name=task_name, every=every, context=context, repeat=repeat, metadata=metadata, - ) + )) - def start(self) -> None: - """Start worker processes (non-blocking). + async def start(self) -> None: + """Start workers and run until they exit. - Spawns N worker processes that poll the Redis queue and execute - tasks from this pool's registry. Returns immediately. - - Workers log to Redis pubsub. Use run() if you want the main - process to collect and display those logs. + Spawns worker processes, forwards logs, and processes scheduled + tasks. This is the foreground entry point — it blocks until all + workers finish. Use ``run()`` for a daemonized version that + handles KeyboardInterrupt and cleanup. """ - # Clear any stale shutdown signal - self._context.shutdown_event.clear() + await self._context.shutdown_event.clear() - # Spawn workers BEFORE setting up log handler to avoid pickling issues - # (StreamHandler has a lock that can't be pickled) + # Spawn workers before log handler to avoid pickling issues self._spawn_workers() - # Set up log handler for receiving worker logs # TODO make this configurable self._log_handler = logging.StreamHandler() self._log_handler.setFormatter(logging.Formatter(DEFAULT_FORMAT)) - def run(self) -> None: - """Start workers and run log collector until interrupted. + await asyncio.gather( + self._process_worker_events(), + self._process_scheduled_tasks(), + ) - Spawns worker processes and runs an async event loop in the main - process that collects logs from workers via Redis pubsub. - The scheduler loop also runs automatically alongside the workers, - polling for due scheduled tasks and enqueuing them. + def run(self) -> None: + """Start workers in a managed event loop with graceful shutdown. - Blocks until all workers exit or KeyboardInterrupt, then shuts - down gracefully. + Calls ``start()`` inside ``asyncio.run()`` and handles + KeyboardInterrupt, shutdown, and connection cleanup. """ async def _loop() -> None: try: - await self._collect_logs() + await self.start() except asyncio.CancelledError: pass finally: - self.shutdown() - await state.backend.close() + await self.shutdown() try: - self.start() asyncio.run(_loop()) except KeyboardInterrupt: pass @@ -436,34 +432,66 @@ def _spawn_workers(self) -> None: self._processes.append(process) print(f"Started worker {worker_id} (PID: {process.pid})") - async def _collect_logs(self) -> None: - """Listen for log messages from workers and run scheduler ticks.""" + async def _process_scheduled_tasks(self) -> None: + """Register pending schedules, then poll for due tasks and enqueue them.""" + for _schedule in self._pending_schedules: + await schedule.register(**_schedule) + self._pending_schedules.clear() + + while any(p.is_alive() for p in self._processes): + await asyncio.sleep(CONF.scheduler_poll_interval) + + for scheduled_task in await backend.schedule.get_due(): + await enqueue( + scheduled_task.task_name, + context=backend.deserialize(scheduled_task.context), + metadata=scheduled_task.metadata, + ) + + if scheduled_task.repeat == 0: + await backend.schedule.remove(scheduled_task.key) + else: + scheduled_task.advance() + await backend.schedule.register(scheduled_task) + + def _partition_key_for(self, task: Task) -> str | None: + """Derive the partition/lock key for a task from its definition.""" + return self._context.tasks[task.task_name].get_lock_key(task.context) + + async def _process_worker_events(self) -> None: + """Handle all events from worker processes via multiprocessing queue.""" assert self._log_handler, "Log handler not initialized" - # Create task to subscribe to logs - log_task = asyncio.create_task(self._process_log_stream()) - - try: - # Poll worker processes and run scheduler - while any(p.is_alive() for p in self._processes): - await asyncio.sleep(0.1) - await schedule.tick() - finally: - log_task.cancel() + while any(p.is_alive() for p in self._processes): try: - await log_task - except asyncio.CancelledError: - pass - - async def _process_log_stream(self) -> None: - """Process log messages from the state backend.""" - assert self._log_handler, "Log handler not initialized" - - async for message in state.subscribe_logs(): - log_message = LogMessage.model_validate_json(message) - self._log_handler.emit(log_message.to_log_record()) - - def shutdown(self, timeout: int | None = None) -> None: + message = self._worker_queue.get_nowait() + except stdlib_queue.Empty: + await asyncio.sleep(0.05) + continue + + match message: + case LogEntry(record=record): + self._log_handler.emit(record.to_log_record()) + + case TaskFailed(task=task, error=error): + if task.retry_count < CONF.max_task_retries: + task.retry_count += 1 + await backend.queue.push( + task.model_dump_json(), + partition_key=self._partition_key_for(task), + high_priority=True, + ) + else: + # TODO incorporate this messaging into the ax.activity stream. + print( + f"Task {task.task_name} failed " + f"after {task.retry_count + 1} attempts, giving up: {error}" + ) + + case ActivityUpdated(): + activity.handler(message) + + async def shutdown(self, timeout: int | None = None) -> None: """Gracefully shutdown all worker processes. For use with start(). If using run(), shutdown is handled automatically. @@ -475,7 +503,7 @@ def shutdown(self, timeout: int | None = None) -> None: timeout = CONF.graceful_shutdown_timeout print("Shutting down worker pool") - self._context.shutdown_event.set() + await self._context.shutdown_event.set() for process in self._processes: process.join(timeout=timeout) @@ -485,4 +513,5 @@ def shutdown(self, timeout: int | None = None) -> None: process.join(timeout=5) self._processes.clear() + await backend.close() print("Worker pool shutdown complete") diff --git a/tests/test_activity_schemas.py b/tests/test_activity_schemas.py index 2addd3d..67120a6 100644 --- a/tests/test_activity_schemas.py +++ b/tests/test_activity_schemas.py @@ -1,5 +1,3 @@ -"""Test activity schema validation and computed fields.""" - import uuid from datetime import datetime, timedelta, UTC diff --git a/tests/test_activity_tracking.py b/tests/test_activity_tracking.py index ab1963f..b40e8a3 100644 --- a/tests/test_activity_tracking.py +++ b/tests/test_activity_tracking.py @@ -1,5 +1,3 @@ -"""Tests for activity tracking functionality.""" - import uuid import pytest @@ -8,21 +6,20 @@ from agentexec import activity from agentexec.activity.models import Activity, ActivityLog, Base, Status -from agentexec.activity.tracker import normalize_agent_id +from agentexec.activity import normalize_agent_id + @pytest.fixture def db_session(): """Set up an in-memory SQLite database for testing.""" - # Create engine and session factory (users manage their own) - engine = create_engine("sqlite:///:memory:", echo=False) - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + from agentexec.core.db import configure_engine - # Create tables + engine = create_engine("sqlite:///:memory:", echo=False) Base.metadata.create_all(bind=engine) + configure_engine(engine) - # Provide a session for the test - session = SessionLocal() + session = sessionmaker(bind=engine)() try: yield session session.commit() @@ -32,11 +29,12 @@ def db_session(): finally: session.close() engine.dispose() + engine.dispose() -def test_create_activity(db_session: Session): +async def test_create_activity(db_session: Session): """Test creating a new activity record.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Task queued for testing", session=db_session, @@ -84,17 +82,17 @@ def test_database_tables_created(): engine.dispose() -def test_update_activity(db_session: Session): +async def test_update_activity(db_session: Session): """Test updating an activity with a new log message.""" # First create an activity - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Initial message", session=db_session, ) # Update the activity - result = activity.update( + result = await activity.update( agent_id=agent_id, message="Processing...", percentage=50, @@ -112,15 +110,15 @@ def test_update_activity(db_session: Session): assert activity_record.logs[1].percentage == 50 -def test_update_activity_with_custom_status(db_session: Session): +async def test_update_activity_with_custom_status(db_session: Session): """Test updating an activity with a custom status.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Initial", session=db_session, ) - activity.update( + await activity.update( agent_id=agent_id, message="Custom status update", status=Status.RUNNING, @@ -134,15 +132,15 @@ def test_update_activity_with_custom_status(db_session: Session): assert latest_log.status == Status.RUNNING -def test_complete_activity(db_session: Session): +async def test_complete_activity(db_session: Session): """Test marking an activity as complete.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Started", session=db_session, ) - result = activity.complete( + result = await activity.complete( agent_id=agent_id, message="Successfully completed", session=db_session, @@ -158,15 +156,15 @@ def test_complete_activity(db_session: Session): assert latest_log.percentage == 100 -def test_complete_activity_custom_percentage(db_session: Session): +async def test_complete_activity_custom_percentage(db_session: Session): """Test marking an activity complete with custom percentage.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Started", session=db_session, ) - activity.complete( + await activity.complete( agent_id=agent_id, message="Done", percentage=95, @@ -179,15 +177,15 @@ def test_complete_activity_custom_percentage(db_session: Session): assert latest_log.percentage == 95 -def test_error_activity(db_session: Session): +async def test_error_activity(db_session: Session): """Test marking an activity as errored.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="test_task", message="Started", session=db_session, ) - result = activity.error( + result = await activity.error( agent_id=agent_id, message="Task failed: connection timeout", session=db_session, @@ -203,36 +201,36 @@ def test_error_activity(db_session: Session): assert latest_log.percentage == 100 -def test_cancel_pending_activities(db_session: Session): +async def test_cancel_pending_activities(db_session: Session): """Test canceling all pending activities.""" # Create some activities in different states - queued_id = activity.create( + queued_id = await activity.create( task_name="queued_task", message="Waiting", session=db_session, ) - running_id = activity.create( + running_id = await activity.create( task_name="running_task", message="Started", session=db_session, ) - activity.update( + await activity.update( agent_id=running_id, message="Running...", status=Status.RUNNING, session=db_session, ) - complete_id = activity.create( + complete_id = await activity.create( task_name="complete_task", message="Started", session=db_session, ) - activity.complete(agent_id=complete_id, session=db_session) + await activity.complete(agent_id=complete_id, session=db_session) # Cancel pending activities - canceled_count = activity.cancel_pending(session=db_session) + canceled_count = await activity.cancel_pending(session=db_session) # Should have canceled the queued and running activities assert canceled_count == 2 @@ -250,18 +248,18 @@ def test_cancel_pending_activities(db_session: Session): assert complete_record.logs[-1].status == Status.COMPLETE # Not changed -def test_list_activities(db_session: Session): +async def test_list_activities(db_session: Session): """Test listing activities with pagination.""" # Create several activities for i in range(5): - activity.create( + await activity.create( task_name=f"task_{i}", message=f"Message {i}", session=db_session, ) # List activities - result = activity.list(db_session, page=1, page_size=3) + result = await activity.list(db_session, page=1, page_size=3) assert len(result.items) == 3 assert result.total == 5 @@ -269,38 +267,38 @@ def test_list_activities(db_session: Session): assert result.page_size == 3 -def test_list_activities_second_page(db_session: Session): +async def test_list_activities_second_page(db_session: Session): """Test listing activities on second page.""" for i in range(5): - activity.create( + await activity.create( task_name=f"task_{i}", message=f"Message {i}", session=db_session, ) - result = activity.list(db_session, page=2, page_size=3) + result = await activity.list(db_session, page=2, page_size=3) assert len(result.items) == 2 # Remaining items assert result.total == 5 assert result.page == 2 -def test_detail_activity(db_session: Session): +async def test_detail_activity(db_session: Session): """Test getting activity detail with all logs.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="detailed_task", message="Initial", session=db_session, ) - activity.update( + await activity.update( agent_id=agent_id, message="Processing", percentage=50, session=db_session, ) - activity.complete(agent_id=agent_id, session=db_session) + await activity.complete(agent_id=agent_id, session=db_session) - result = activity.detail(db_session, agent_id) + result = await activity.detail(db_session, agent_id) assert result is not None assert result.agent_id == agent_id @@ -311,33 +309,33 @@ def test_detail_activity(db_session: Session): assert result.logs[2].status == Status.COMPLETE -def test_detail_activity_not_found(db_session: Session): +async def test_detail_activity_not_found(db_session: Session): """Test getting detail for non-existent activity returns None.""" fake_id = uuid.uuid4() - result = activity.detail(db_session, fake_id) + result = await activity.detail(db_session, fake_id) assert result is None -def test_detail_activity_with_string_id(db_session: Session): +async def test_detail_activity_with_string_id(db_session: Session): """Test getting activity detail with string agent_id.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="string_id_task", message="Test", session=db_session, ) # Use string ID - result = activity.detail(db_session, str(agent_id)) + result = await activity.detail(db_session, str(agent_id)) assert result is not None assert result.agent_id == agent_id -def test_create_activity_with_custom_agent_id(db_session: Session): +async def test_create_activity_with_custom_agent_id(db_session: Session): """Test creating activity with a custom agent_id.""" custom_id = uuid.uuid4() - agent_id = activity.create( + agent_id = await activity.create( task_name="custom_id_task", message="Test", agent_id=custom_id, @@ -350,10 +348,10 @@ def test_create_activity_with_custom_agent_id(db_session: Session): assert activity_record is not None -def test_create_activity_with_string_agent_id(db_session: Session): +async def test_create_activity_with_string_agent_id(db_session: Session): """Test creating activity with a string agent_id.""" custom_id = uuid.uuid4() - agent_id = activity.create( + agent_id = await activity.create( task_name="string_agent_id_task", message="Test", agent_id=str(custom_id), @@ -363,12 +361,9 @@ def test_create_activity_with_string_agent_id(db_session: Session): assert agent_id == custom_id -# --- Metadata Tests --- - - -def test_create_activity_with_metadata(db_session: Session): +async def test_create_activity_with_metadata(db_session: Session): """Test creating activity with metadata.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="metadata_task", message="Test with metadata", session=db_session, @@ -380,9 +375,9 @@ def test_create_activity_with_metadata(db_session: Session): assert activity_record.metadata_ == {"organization_id": "org-123", "user_id": "user-456"} -def test_create_activity_without_metadata(db_session: Session): +async def test_create_activity_without_metadata(db_session: Session): """Test that metadata is None by default.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="no_metadata_task", message="Test without metadata", session=db_session, @@ -393,22 +388,22 @@ def test_create_activity_without_metadata(db_session: Session): assert activity_record.metadata_ is None -def test_list_activities_with_metadata_filter(db_session: Session): +async def test_list_activities_with_metadata_filter(db_session: Session): """Test filtering activities by metadata.""" # Create activities for different organizations - activity.create( + await activity.create( task_name="task_org_a", message="Org A task", session=db_session, metadata={"organization_id": "org-A"}, ) - activity.create( + await activity.create( task_name="task_org_a_2", message="Org A task 2", session=db_session, metadata={"organization_id": "org-A"}, ) - activity.create( + await activity.create( task_name="task_org_b", message="Org B task", session=db_session, @@ -416,7 +411,7 @@ def test_list_activities_with_metadata_filter(db_session: Session): ) # Filter by org-A - result = activity.list( + result = await activity.list( db_session, metadata_filter={"organization_id": "org-A"}, ) @@ -427,7 +422,7 @@ def test_list_activities_with_metadata_filter(db_session: Session): assert item.metadata["organization_id"] == "org-A" # Filter by org-B - result = activity.list( + result = await activity.list( db_session, metadata_filter={"organization_id": "org-B"}, ) @@ -436,22 +431,22 @@ def test_list_activities_with_metadata_filter(db_session: Session): assert result.items[0].metadata["organization_id"] == "org-B" # Filter by non-existent org - result = activity.list( + result = await activity.list( db_session, metadata_filter={"organization_id": "org-C"}, ) assert result.total == 0 -def test_list_activities_with_multiple_metadata_filters(db_session: Session): +async def test_list_activities_with_multiple_metadata_filters(db_session: Session): """Test filtering activities by multiple metadata fields.""" - activity.create( + await activity.create( task_name="task_1", message="User 1 in Org A", session=db_session, metadata={"organization_id": "org-A", "user_id": "user-1"}, ) - activity.create( + await activity.create( task_name="task_2", message="User 2 in Org A", session=db_session, @@ -459,37 +454,37 @@ def test_list_activities_with_multiple_metadata_filters(db_session: Session): ) # Filter by both org and user - result = activity.list( + result = await activity.list( db_session, metadata_filter={"organization_id": "org-A", "user_id": "user-1"}, ) assert result.total == 1 -def test_detail_activity_with_metadata(db_session: Session): +async def test_detail_activity_with_metadata(db_session: Session): """Test getting activity detail includes metadata.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="detailed_metadata_task", message="Test", session=db_session, metadata={"organization_id": "org-123"}, ) - result = activity.detail(db_session, agent_id) + result = await activity.detail(db_session, agent_id) assert result is not None assert result.metadata == {"organization_id": "org-123"} -def test_detail_activity_with_metadata_filter_match(db_session: Session): +async def test_detail_activity_with_metadata_filter_match(db_session: Session): """Test detail returns activity when metadata filter matches.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="filter_match_task", message="Test", session=db_session, metadata={"organization_id": "org-A"}, ) - result = activity.detail( + result = await activity.detail( db_session, agent_id, metadata_filter={"organization_id": "org-A"}, @@ -498,9 +493,9 @@ def test_detail_activity_with_metadata_filter_match(db_session: Session): assert result.agent_id == agent_id -def test_detail_activity_with_metadata_filter_no_match(db_session: Session): +async def test_detail_activity_with_metadata_filter_no_match(db_session: Session): """Test detail returns None when metadata filter doesn't match.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="filter_no_match_task", message="Test", session=db_session, @@ -508,7 +503,7 @@ def test_detail_activity_with_metadata_filter_no_match(db_session: Session): ) # Try to access with wrong organization - result = activity.detail( + result = await activity.detail( db_session, agent_id, metadata_filter={"organization_id": "org-B"}, @@ -516,15 +511,15 @@ def test_detail_activity_with_metadata_filter_no_match(db_session: Session): assert result is None -def test_detail_activity_no_metadata_with_filter(db_session: Session): +async def test_detail_activity_no_metadata_with_filter(db_session: Session): """Test detail returns None when activity has no metadata but filter is applied.""" - agent_id = activity.create( + agent_id = await activity.create( task_name="no_metadata_with_filter", message="Test", session=db_session, ) - result = activity.detail( + result = await activity.detail( db_session, agent_id, metadata_filter={"organization_id": "org-A"}, @@ -532,28 +527,28 @@ def test_detail_activity_no_metadata_with_filter(db_session: Session): assert result is None -def test_list_metadata_accessible_as_attribute(db_session: Session): +async def test_list_metadata_accessible_as_attribute(db_session: Session): """Test that metadata is accessible as an attribute on schema objects.""" - activity.create( + await activity.create( task_name="list_metadata_task", message="Test", session=db_session, metadata={"key1": "value1", "key2": "value2"}, ) - result = activity.list(db_session) + result = await activity.list(db_session) assert result.total == 1 # Metadata is accessible as attribute for programmatic use assert result.items[0].metadata == {"key1": "value1", "key2": "value2"} -def test_metadata_excluded_from_serialization(db_session: Session): +async def test_metadata_excluded_from_serialization(db_session: Session): """Test that metadata is excluded from JSON/dict serialization by default. This prevents accidental leakage of tenant info through API responses. Users who want metadata in responses should explicitly include it. """ - agent_id = activity.create( + agent_id = await activity.create( task_name="serialization_test", message="Test", session=db_session, @@ -561,12 +556,12 @@ def test_metadata_excluded_from_serialization(db_session: Session): ) # List view - metadata excluded from serialization - result = activity.list(db_session) + result = await activity.list(db_session) item_dict = result.items[0].model_dump() assert "metadata" not in item_dict # Detail view - metadata excluded from serialization - detail = activity.detail(db_session, agent_id) + detail = await activity.detail(db_session, agent_id) assert detail is not None detail_dict = detail.model_dump() assert "metadata" not in detail_dict diff --git a/tests/test_config.py b/tests/test_config.py index af280fc..3aa6ec4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,5 +1,3 @@ -"""Test configuration handling.""" - import os import pytest @@ -15,10 +13,10 @@ def test_default_table_prefix(self): config = Config() assert config.table_prefix == "agentexec_" - def test_default_queue_name(self): - """Test default queue name.""" + def test_default_queue_prefix(self): + """Test default queue prefix.""" config = Config() - assert config.queue_name == "agentexec_tasks" + assert config.queue_prefix == "agentexec_tasks" def test_default_num_workers(self): """Test default number of workers.""" @@ -107,11 +105,11 @@ def test_table_prefix_from_env(self): config = Config() assert config.table_prefix == "custom_" - def test_queue_name_from_env(self): - """Test queue_name from environment variable.""" + def test_queue_prefix_from_env(self): + """Test queue_prefix from environment variable (with backwards compat alias).""" os.environ["AGENTEXEC_QUEUE_NAME"] = "my_queue" config = Config() - assert config.queue_name == "my_queue" + assert config.queue_prefix == "my_queue" def test_graceful_shutdown_timeout_from_env(self): """Test graceful_shutdown_timeout from environment variable.""" @@ -191,7 +189,7 @@ def test_conf_is_config_instance(self): def test_conf_has_expected_attributes(self): """Test that CONF has all expected attributes.""" assert hasattr(CONF, "table_prefix") - assert hasattr(CONF, "queue_name") + assert hasattr(CONF, "queue_prefix") assert hasattr(CONF, "num_workers") assert hasattr(CONF, "graceful_shutdown_timeout") assert hasattr(CONF, "redis_url") diff --git a/tests/test_db.py b/tests/test_db.py index 4714751..01df665 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,134 +1,38 @@ -"""Test database session management.""" - import pytest -from sqlalchemy import create_engine, text -from sqlalchemy.orm import Session +from sqlalchemy import create_engine -from agentexec.core.db import ( - Base, - get_global_session, - remove_global_session, - set_global_session, -) +from agentexec.core.db import Base, configure_engine, get_session @pytest.fixture def test_engine(): - """Create a test SQLite engine.""" - engine = create_engine("sqlite:///:memory:", echo=False) + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(bind=engine) + configure_engine(engine) yield engine engine.dispose() -@pytest.fixture(autouse=True) -def cleanup_session(): - """Cleanup global session after each test.""" - yield - try: - remove_global_session() - except Exception: - pass - - -def test_base_class_exists(): - """Test that Base class is exported and usable.""" - assert Base is not None - assert hasattr(Base, "metadata") - - -def test_set_global_session(test_engine): - """Test that set_global_session configures the session factory.""" - set_global_session(test_engine) - - # Should be able to get a session now - session = get_global_session() - assert isinstance(session, Session) - - -def test_get_global_session_returns_session(test_engine): - """Test that get_global_session returns a working session.""" - set_global_session(test_engine) - - session = get_global_session() - - # Verify it's a working session - result = session.execute(text("SELECT 1")) - assert result.scalar() == 1 - - -def test_get_global_session_singleton(test_engine): - """Test that get_global_session returns the same session instance.""" - set_global_session(test_engine) - - session1 = get_global_session() - session2 = get_global_session() - - # Should be the same session (scoped_session behavior) - assert session1 is session2 - - -def test_remove_global_session(test_engine): - """Test that remove_global_session closes the session.""" - set_global_session(test_engine) - - session1 = get_global_session() - remove_global_session() +def test_configure_engine(test_engine): + """configure_engine makes get_session available.""" + session = get_session() + assert session is not None + session.close() - # Getting session again should return a different instance - session2 = get_global_session() - # They should be different sessions after remove - assert session1 is not session2 +def test_get_session_context_manager(test_engine): + """get_session works as a context manager.""" + with get_session() as session: + assert session is not None -def test_session_with_tables(test_engine): - """Test that session works with table creation.""" - # Create the tables - Base.metadata.create_all(bind=test_engine) - - set_global_session(test_engine) - session = get_global_session() - - # Session should be able to query (though tables may be empty) - result = session.execute(text("SELECT name FROM sqlite_master WHERE type='table'")) - tables = [row[0] for row in result] - - # Tables from Base.metadata should exist - assert isinstance(tables, list) - - -def test_multiple_set_global_session_calls(test_engine): - """Test that multiple set_global_session calls work correctly.""" - set_global_session(test_engine) - session1 = get_global_session() - - # Create another engine - engine2 = create_engine("sqlite:///:memory:") - - # Reconfigure with new engine - set_global_session(engine2) - session2 = get_global_session() - - # Sessions should work with their respective engines - result = session2.execute(text("SELECT 1")) - assert result.scalar() == 1 - - engine2.dispose() - - -def test_session_lifecycle(): - """Test complete session lifecycle: set -> use -> remove.""" - engine = create_engine("sqlite:///:memory:") - - # Set - set_global_session(engine) - - # Use - session = get_global_session() - session.execute(text("SELECT 1")) - - # Remove - remove_global_session() - - # Cleanup - engine.dispose() +def test_get_session_without_configure_raises(): + """get_session raises if configure_engine hasn't been called.""" + import agentexec.core.db as db_module + old_factory = db_module._session_factory + db_module._session_factory = None + try: + with pytest.raises(RuntimeError, match="Database engine not configured"): + get_session() + finally: + db_module._session_factory = old_factory diff --git a/tests/test_kafka_integration.py b/tests/test_kafka_integration.py new file mode 100644 index 0000000..8393a0d --- /dev/null +++ b/tests/test_kafka_integration.py @@ -0,0 +1,399 @@ +"""Kafka backend integration tests. + +These tests run against a real Kafka broker. They are skipped if the +``aiokafka`` package is not installed or ``KAFKA_BOOTSTRAP_SERVERS`` is +not set. + +Run locally: + + docker compose -f docker-compose.kafka.yml up -d + + AGENTEXEC_STATE_BACKEND=agentexec.state.kafka \\ + KAFKA_BOOTSTRAP_SERVERS=localhost:9092 \\ + uv run pytest tests/test_kafka_integration.py -v + + docker compose -f docker-compose.kafka.yml down +""" + +from __future__ import annotations + +import asyncio +import os +import uuid + +import pytest +from pydantic import BaseModel + +_skip_reason = None + +if not os.environ.get("KAFKA_BOOTSTRAP_SERVERS"): + _skip_reason = "KAFKA_BOOTSTRAP_SERVERS not set" +else: + try: + import aiokafka # noqa: F401 + except ImportError: + _skip_reason = "aiokafka not installed (pip install agentexec[kafka])" + +if _skip_reason: + pytest.skip(_skip_reason, allow_module_level=True) + + +from agentexec.state import backend # noqa: E402 +from agentexec.state.kafka import Backend as KafkaBackend # noqa: E402 + +# Convenience alias to keep test code concise +_kb: KafkaBackend = backend # type: ignore[assignment] + + +class SampleResult(BaseModel): + status: str + value: int + + +class TaskContext(BaseModel): + query: str + + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +@pytest.fixture(autouse=True) +async def kafka_cleanup(): + """Ensure caches are clean before/after each test.""" + await _kb.state.clear() + yield + await _kb.state.clear() + + +@pytest.fixture(autouse=True, scope="module") +async def close_connections(): + """Close all Kafka connections once after the module completes.""" + yield + await _kb.close() + + +class TestKVStore: + async def test_store_set_and_get(self): + """Values written via store_set are readable from the cache.""" + key = f"test:kv:{uuid.uuid4()}" + await _kb.state.set(key, b"hello-world") + result = await _kb.state.get(key) + assert result == b"hello-world" + + async def test_store_get_missing_key(self): + """Reading a non-existent key returns None.""" + result = await _kb.state.get(f"test:missing:{uuid.uuid4()}") + assert result is None + + async def test_store_delete(self): + """Deleting a key removes it from the cache.""" + key = f"test:kv:{uuid.uuid4()}" + await _kb.state.set(key, b"to-delete") + assert await _kb.state.get(key) == b"to-delete" + + await _kb.state.delete(key) + assert await _kb.state.get(key) is None + + async def test_store_set_overwrites(self): + """A second store_set for the same key overwrites the value.""" + key = f"test:kv:{uuid.uuid4()}" + await _kb.state.set(key, b"v1") + await _kb.state.set(key, b"v2") + assert await _kb.state.get(key) == b"v2" + + +class TestCounters: + async def test_incr_from_zero(self): + """Incrementing a non-existent counter starts at 1.""" + key = f"test:counter:{uuid.uuid4()}" + result = await _kb.state.counter_incr(key) + assert result == 1 + + async def test_incr_multiple(self): + """Multiple increments accumulate.""" + key = f"test:counter:{uuid.uuid4()}" + await _kb.state.counter_incr(key) + await _kb.state.counter_incr(key) + result = await _kb.state.counter_incr(key) + assert result == 3 + + async def test_decr(self): + """Decrement reduces the counter.""" + key = f"test:counter:{uuid.uuid4()}" + await _kb.state.counter_incr(key) + await _kb.state.counter_incr(key) + result = await _kb.state.counter_decr(key) + assert result == 1 + + +class TestSortedIndex: + async def test_index_add_and_range(self): + """Members added with scores can be queried by score range.""" + key = f"test:index:{uuid.uuid4()}" + await _kb.state.index_add(key, {"task_a": 100.0, "task_b": 200.0, "task_c": 300.0}) + + result = await _kb.state.index_range(key, 0.0, 250.0) + names = [item.decode() for item in result] + assert "task_a" in names + assert "task_b" in names + assert "task_c" not in names + + async def test_index_remove(self): + """Removed members no longer appear in range queries.""" + key = f"test:index:{uuid.uuid4()}" + await _kb.state.index_add(key, {"task_a": 100.0, "task_b": 200.0}) + await _kb.state.index_remove(key, "task_a") + + result = await _kb.state.index_range(key, 0.0, 999.0) + names = [item.decode() for item in result] + assert "task_a" not in names + assert "task_b" in names + + +class TestSerialization: + def test_roundtrip(self): + """serialize → deserialize preserves type and data.""" + original = SampleResult(status="ok", value=42) + data = _kb.serialize(original) + restored = _kb.deserialize(data) + assert type(restored) is SampleResult + assert restored == original + + def test_format_key_joins_with_dots(self): + """Kafka backend uses dots as key separators.""" + assert _kb.format_key("agentexec", "result", "123") == "agentexec.result.123" + + +class TestQueue: + async def test_push_and_pop(self): + """A pushed task can be popped from the queue.""" + # Use a unique queue name per test to avoid cross-test interference + q = f"kafka_test_{uuid.uuid4().hex[:8]}" + import json + + task_data = { + "task_name": "test_task", + "context": {"query": "hello"}, + "agent_id": str(uuid.uuid4()), + } + await _kb.queue.push(q, json.dumps(task_data)) + + result = await _kb.queue.pop(q, timeout=10) + assert result is not None + assert result["task_name"] == "test_task" + assert result["context"]["query"] == "hello" + + async def test_pop_empty_queue_returns_none(self): + """Popping an empty queue returns None after timeout.""" + q = f"kafka_empty_{uuid.uuid4().hex[:8]}" + result = await _kb.queue.pop(q, timeout=1) + assert result is None + + async def test_push_with_partition_key(self): + """Tasks with partition_key are routed deterministically.""" + q = f"kafka_pk_{uuid.uuid4().hex[:8]}" + import json + + task_data = { + "task_name": "keyed_task", + "context": {"query": "keyed"}, + "agent_id": str(uuid.uuid4()), + } + await _kb.queue.push(q, json.dumps(task_data), partition_key="user-123") + + result = await _kb.queue.pop(q, timeout=10) + assert result is not None + assert result["task_name"] == "keyed_task" + + async def test_multiple_push_pop_ordering(self): + """Tasks with the same partition key are consumed in order.""" + q = f"kafka_order_{uuid.uuid4().hex[:8]}" + import json + + ids = [str(uuid.uuid4()) for _ in range(3)] + for agent_id in ids: + await _kb.queue.push(q, json.dumps({ + "task_name": "order_test", + "context": {"query": "test"}, + "agent_id": agent_id, + }), partition_key="same-key") + + received = [] + for _ in range(3): + result = await _kb.queue.pop(q, timeout=10) + assert result is not None + received.append(result["agent_id"]) + + assert received == ids + + +class TestActivity: + async def test_create_and_get(self): + """Creating an activity makes it retrievable.""" + agent_id = uuid.uuid4() + await _kb.activity.create(agent_id, "test_task", "Agent queued", None) + + record = await _kb.activity.get(agent_id) + assert record is not None + assert record["agent_id"] == str(agent_id) + assert record["agent_type"] == "test_task" + assert len(record["logs"]) == 1 + assert record["logs"][0]["status"] == "queued" + assert record["logs"][0]["message"] == "Agent queued" + + async def test_append_log(self): + """Appending a log entry adds to the record.""" + agent_id = uuid.uuid4() + await _kb.activity.create(agent_id, "test_task", "Queued", None) + await _kb.activity.append_log(agent_id, "Processing", "running", 50) + + record = await _kb.activity.get(agent_id) + assert len(record["logs"]) == 2 + assert record["logs"][1]["status"] == "running" + assert record["logs"][1]["message"] == "Processing" + assert record["logs"][1]["percentage"] == 50 + + async def test_activity_lifecycle(self): + """Full lifecycle: create → update → complete.""" + agent_id = uuid.uuid4() + await _kb.activity.create(agent_id, "lifecycle_task", "Queued", None) + await _kb.activity.append_log(agent_id, "Started", "running", 0) + await _kb.activity.append_log(agent_id, "Halfway", "running", 50) + await _kb.activity.append_log(agent_id, "Done", "complete", 100) + + record = await _kb.activity.get(agent_id) + assert len(record["logs"]) == 4 + assert record["logs"][-1]["status"] == "complete" + assert record["logs"][-1]["percentage"] == 100 + + @pytest.mark.skip(reason="Aggregate queries accumulate across tests on shared topic") + async def test_activity_list_pagination(self): + """activity_list returns paginated results.""" + for i in range(5): + await _kb.activity.create(uuid.uuid4(), f"task_{i}", "Queued", None) + + rows, total = await _kb.activity.list(page=1, page_size=3) + assert total == 5 + assert len(rows) == 3 + + rows2, total2 = await _kb.activity.list(page=2, page_size=3) + assert total2 == 5 + assert len(rows2) == 2 + + @pytest.mark.skip(reason="Aggregate queries accumulate across tests on shared topic") + async def test_activity_count_active(self): + """count_active returns queued + running activities.""" + a1 = uuid.uuid4() + a2 = uuid.uuid4() + a3 = uuid.uuid4() + + await _kb.activity.create(a1, "task", "Queued", None) + await _kb.activity.create(a2, "task", "Queued", None) + await _kb.activity.create(a3, "task", "Queued", None) + + # Mark one as running, one as complete + await _kb.activity.append_log(a2, "Running", "running", 10) + await _kb.activity.append_log(a3, "Done", "complete", 100) + + count = await _kb.activity.count_active() + assert count == 2 # a1 (queued) + a2 (running) + + @pytest.mark.skip(reason="Aggregate queries accumulate across tests on shared topic") + async def test_activity_get_pending_ids(self): + """get_pending_ids returns agent_ids for queued/running activities.""" + a1 = uuid.uuid4() + a2 = uuid.uuid4() + a3 = uuid.uuid4() + + await _kb.activity.create(a1, "task", "Queued", None) + await _kb.activity.create(a2, "task", "Queued", None) + await _kb.activity.create(a3, "task", "Queued", None) + + await _kb.activity.append_log(a3, "Done", "complete", 100) + + pending = await _kb.activity.get_pending_ids() + pending_set = {str(p) for p in pending} + assert str(a1) in pending_set + assert str(a2) in pending_set + assert str(a3) not in pending_set + + async def test_activity_with_metadata(self): + """Metadata is stored and filterable.""" + agent_id = uuid.uuid4() + await _kb.activity.create( + agent_id, "task", "Queued", + metadata={"org_id": "org-123", "env": "test"}, + ) + + # Retrieve without filter + record = await _kb.activity.get(agent_id) + assert record["metadata"] == {"org_id": "org-123", "env": "test"} + + # Filter match + record = await _kb.activity.get(agent_id, metadata_filter={"org_id": "org-123"}) + assert record is not None + + # Filter mismatch + record = await _kb.activity.get(agent_id, metadata_filter={"org_id": "org-999"}) + assert record is None + + async def test_activity_get_nonexistent(self): + """Getting a non-existent activity returns None.""" + result = await _kb.activity.get(uuid.uuid4()) + assert result is None + + +class TestLogPubSub: + async def test_publish_and_subscribe(self): + """Published log messages arrive via subscribe.""" + received = [] + + async def subscriber(): + async for msg in _kb.state.log_subscribe(): + received.append(msg) + if len(received) >= 2: + break + + sub_task = asyncio.create_task(subscriber()) + await asyncio.sleep(2) + + await _kb.state.log_publish('{"level":"info","msg":"hello"}') + await _kb.state.log_publish('{"level":"info","msg":"world"}') + + # Wait for messages to arrive (with timeout) + try: + await asyncio.wait_for(sub_task, timeout=10) + except asyncio.TimeoutError: + sub_task.cancel() + try: + await sub_task + except asyncio.CancelledError: + pass + + assert len(received) >= 2 + assert '{"level":"info","msg":"hello"}' in received + assert '{"level":"info","msg":"world"}' in received + + +class TestConnection: + async def test_ensure_topic_idempotent(self): + """ensure_topic can be called multiple times without error.""" + topic = f"test_ensure_{uuid.uuid4().hex[:8]}" + await _kb.ensure_topic(topic) + await _kb.ensure_topic(topic) # Should not raise + + async def test_client_id_includes_worker_id(self): + """client_id includes worker_id when configured.""" + _kb.configure(worker_id="42") + cid = _kb._client_id("producer") + assert "42" in cid + assert "producer" in cid + + # Reset + _kb._worker_id = None + + async def test_produce_and_topic_creation(self): + """produce() auto-creates the topic if needed.""" + topic = f"test_produce_{uuid.uuid4().hex[:8]}" + await _kb.produce(topic, b"test-value", key=b"test-key") + # If we got here without error, produce and topic creation worked diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index cf81680..bf9213d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,5 +1,3 @@ -"""Test Pipeline orchestration functionality.""" - from dataclasses import dataclass, field from unittest.mock import MagicMock diff --git a/tests/test_pipeline_flow.py b/tests/test_pipeline_flow.py index b8f0ff6..d7b9481 100644 --- a/tests/test_pipeline_flow.py +++ b/tests/test_pipeline_flow.py @@ -1,12 +1,3 @@ -"""Test Pipeline type flow validation. - -This module tests the type checking between pipeline steps, ensuring that: -- Return types from one step match parameter types of the next step -- Tuple returns are properly unpacked into multiple parameters -- Type mismatches are caught at validation time -- Subclass relationships are respected -""" - from dataclasses import dataclass, field from unittest.mock import MagicMock @@ -16,11 +7,6 @@ from agentexec.pipeline import Pipeline -# ============================================================================= -# Test Models -# ============================================================================= - - class Context(BaseModel): """Input context for pipeline tests.""" @@ -52,11 +38,6 @@ class Combined(BaseModel): b: str -# ============================================================================= -# Fixtures -# ============================================================================= - - @dataclass class MockWorkerContext: """Mock context for testing.""" @@ -78,11 +59,6 @@ def pipeline(mock_pool): return Pipeline(mock_pool) -# ============================================================================= -# Valid Flows - Single Value -# ============================================================================= - - class TestValidSingleValueFlows: """Test valid single-value flows between steps.""" @@ -148,11 +124,6 @@ async def consume_base(self, result: ResultA) -> ResultC: assert result.value == "from_derived" -# ============================================================================= -# Valid Flows - Tuple Unpacking -# ============================================================================= - - class TestValidTupleFlows: """Test valid tuple return/parameter flows.""" @@ -215,11 +186,6 @@ async def combine(self, left: ResultA, right: ResultA) -> Combined: assert result.b == "right:data" -# ============================================================================= -# Invalid Flows - Count Mismatches -# ============================================================================= - - class TestInvalidCountMismatches: """Test that count mismatches between steps are caught.""" @@ -279,11 +245,6 @@ async def second(self) -> ResultC: await pipeline.run(Context(value="42")) -# ============================================================================= -# Invalid Flows - Type Mismatches -# ============================================================================= - - class TestInvalidTypeMismatches: """Test that type mismatches between steps are caught.""" @@ -328,11 +289,6 @@ class UnrelatedPipeline(pipeline.Base): pipeline._validate_type_flow() -# ============================================================================= -# Invalid Flows - Final Step Returns Tuple -# ============================================================================= - - class TestInvalidFinalStepTuple: """Test that final step returning tuple is rejected.""" @@ -357,11 +313,6 @@ class TupleFinalPipeline(pipeline.Base): pipeline._validate_type_flow() -# ============================================================================= -# Edge Cases -# ============================================================================= - - class TestInvalidNoSteps: """Test that pipelines with no steps are rejected.""" diff --git a/tests/test_public_api.py b/tests/test_public_api.py index 8ce615f..4d32c44 100644 --- a/tests/test_public_api.py +++ b/tests/test_public_api.py @@ -1,5 +1,3 @@ -"""Test that the public API is properly exposed.""" - import uuid import pytest diff --git a/tests/test_queue.py b/tests/test_queue.py index db71c3f..6217727 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -1,5 +1,3 @@ -"""Test task queue operations.""" - import json import uuid @@ -8,7 +6,8 @@ from pydantic import BaseModel import agentexec as ax -from agentexec.core.queue import Priority, dequeue, enqueue +from agentexec.core.queue import Priority, enqueue +from agentexec.state import backend class SampleContext(BaseModel): @@ -20,31 +19,17 @@ class SampleContext(BaseModel): @pytest.fixture def fake_redis(monkeypatch): - """Setup fake redis for state backend with shared state.""" - import fakeredis - - # Create a shared FakeServer so sync and async clients share data - server = fakeredis.FakeServer() - fake_redis_sync = fakeredis.FakeRedis(server=server, decode_responses=False) - fake_redis = fake_aioredis.FakeRedis(server=server, decode_responses=False) - - def get_fake_sync_client(): - return fake_redis_sync - - def get_fake_async_client(): - return fake_redis - - monkeypatch.setattr("agentexec.state.redis_backend._get_sync_client", get_fake_sync_client) - monkeypatch.setattr("agentexec.state.redis_backend._get_async_client", get_fake_async_client) - - yield fake_redis + """Setup fake redis for state backend.""" + fake = fake_aioredis.FakeRedis(decode_responses=False) + monkeypatch.setattr(backend, "_client", fake) + yield fake @pytest.fixture def mock_activity_create(monkeypatch): """Mock activity.create to avoid database dependency.""" - def mock_create(*args, **kwargs): + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @@ -65,9 +50,8 @@ async def test_enqueue_creates_task(fake_redis, mock_activity_create) -> None: assert task is not None assert task.task_name == "test_task" assert isinstance(task.agent_id, uuid.UUID) - assert isinstance(task.context, SampleContext) - assert task.context.message == "test" - assert task.context.value == 42 + assert task.context["message"] == "test" + assert task.context["value"] == 42 async def test_enqueue_pushes_to_redis(fake_redis, mock_activity_create) -> None: @@ -77,7 +61,7 @@ async def test_enqueue_pushes_to_redis(fake_redis, mock_activity_create) -> None task = await enqueue("test_task", ctx) # Check Redis has the task - task_json = await fake_redis.rpop(ax.CONF.queue_name) + task_json = await fake_redis.rpop(ax.CONF.queue_prefix) assert task_json is not None task_data = json.loads(task_json) @@ -94,7 +78,7 @@ async def test_enqueue_low_priority_lpush(fake_redis, mock_activity_create) -> N # LPUSH adds to left, RPOP takes from right # So we should use LPOP to see it - task_json = await fake_redis.lpop(ax.CONF.queue_name) + task_json = await fake_redis.lpop(ax.CONF.queue_prefix) assert task_json is not None @@ -107,21 +91,11 @@ async def test_enqueue_high_priority_rpush(fake_redis, mock_activity_create) -> await enqueue("high_task", SampleContext(message="high"), priority=Priority.HIGH) # High priority should be at the front (RPOP side) - task_json = await fake_redis.rpop(ax.CONF.queue_name) + task_json = await fake_redis.rpop(ax.CONF.queue_prefix) task_data = json.loads(task_json) assert task_data["task_name"] == "high_task" -async def test_enqueue_custom_queue_name(fake_redis, mock_activity_create) -> None: - """Test enqueue with custom queue name.""" - ctx = SampleContext(message="custom") - - await enqueue("test_task", ctx, queue_name="custom_queue") - - # Check custom queue - task_json = await fake_redis.rpop("custom_queue") - assert task_json is not None - async def test_dequeue_returns_task_data(fake_redis) -> None: """Test that dequeue returns parsed task data.""" @@ -131,10 +105,10 @@ async def test_dequeue_returns_task_data(fake_redis) -> None: "context": {"message": "dequeued", "value": 100}, "agent_id": str(uuid.uuid4()), } - await fake_redis.lpush(ax.CONF.queue_name, json.dumps(task_data).encode()) + await fake_redis.lpush(ax.CONF.queue_prefix, json.dumps(task_data).encode()) # Dequeue - result = await dequeue(timeout=1) + result = await backend.queue.pop(timeout=1) assert result is not None assert result["task_name"] == "test_task" @@ -145,25 +119,11 @@ async def test_dequeue_returns_task_data(fake_redis) -> None: async def test_dequeue_returns_none_on_empty_queue(fake_redis) -> None: """Test that dequeue returns None when queue is empty.""" # timeout=1 because timeout=0 means block indefinitely in Redis BRPOP - result = await dequeue(timeout=1) + result = await backend.queue.pop(timeout=1) assert result is None -async def test_dequeue_custom_queue_name(fake_redis) -> None: - """Test dequeue with custom queue name.""" - task_data = { - "task_name": "custom_task", - "context": {"message": "test"}, - "agent_id": str(uuid.uuid4()), - } - await fake_redis.lpush("custom_queue", json.dumps(task_data).encode()) - - result = await dequeue(queue_name="custom_queue", timeout=1) - - assert result is not None - assert result["task_name"] == "custom_task" - async def test_dequeue_brpop_behavior(fake_redis) -> None: """Test that dequeue uses BRPOP (right side of list).""" @@ -171,11 +131,11 @@ async def test_dequeue_brpop_behavior(fake_redis) -> None: task1 = {"task_name": "first", "context": {}, "agent_id": str(uuid.uuid4())} task2 = {"task_name": "second", "context": {}, "agent_id": str(uuid.uuid4())} - await fake_redis.lpush(ax.CONF.queue_name, json.dumps(task1).encode()) - await fake_redis.lpush(ax.CONF.queue_name, json.dumps(task2).encode()) + await fake_redis.lpush(ax.CONF.queue_prefix, json.dumps(task1).encode()) + await fake_redis.lpush(ax.CONF.queue_prefix, json.dumps(task2).encode()) # BRPOP should get the first task (oldest) from the right - result = await dequeue(timeout=1) + result = await backend.queue.pop(timeout=1) assert result is not None assert result["task_name"] == "first" @@ -188,7 +148,7 @@ async def test_enqueue_dequeue_roundtrip(fake_redis, mock_activity_create) -> No task = await enqueue("roundtrip_task", ctx) # Dequeue - result = await dequeue(timeout=1) + result = await backend.queue.pop(timeout=1) assert result is not None assert result["task_name"] == "roundtrip_task" @@ -207,6 +167,6 @@ async def test_multiple_enqueue_fifo_order(fake_redis, mock_activity_create) -> # Dequeue should be in FIFO order for i in range(3): - result = await dequeue(timeout=1) + result = await backend.queue.pop(timeout=1) assert result is not None assert result["task_name"] == f"task_{i}" diff --git a/tests/test_queue_partitions.py b/tests/test_queue_partitions.py new file mode 100644 index 0000000..deac6a7 --- /dev/null +++ b/tests/test_queue_partitions.py @@ -0,0 +1,173 @@ +"""Tests for the partitioned Redis queue — SCAN-based dequeue with per-partition locking.""" + +import json +import uuid + +import pytest +from fakeredis import aioredis as fake_aioredis +from pydantic import BaseModel + +import agentexec as ax +from agentexec.state import backend + + +def _task_json(task_name: str = "test", **overrides) -> str: + data = { + "task_name": task_name, + "context": {"message": "hello"}, + "agent_id": str(uuid.uuid4()), + **overrides, + } + return json.dumps(data) + + +@pytest.fixture +def fake_redis(monkeypatch): + fake = fake_aioredis.FakeRedis(decode_responses=False) + monkeypatch.setattr(backend, "_client", fake) + yield fake + + +class TestPartitionRouting: + async def test_push_to_default_queue(self, fake_redis): + await backend.queue.push(_task_json("t1")) + assert await fake_redis.llen(ax.CONF.queue_prefix) == 1 + + async def test_push_to_partition_queue(self, fake_redis): + await backend.queue.push(_task_json("t1"), partition_key="user:42") + partition_key = f"{ax.CONF.queue_prefix}:user:42" + assert await fake_redis.llen(partition_key) == 1 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 0 + + async def test_pop_from_default_queue_no_lock(self, fake_redis): + """Default queue tasks are popped without acquiring a lock.""" + await backend.queue.push(_task_json("t1")) + result = await backend.queue.pop(timeout=1) + + assert result is not None + assert result["task_name"] == "t1" + + # No lock key should exist for the default queue + keys = [k async for k in fake_redis.scan_iter(match=b"*:lock")] + assert len(keys) == 0 + + +class TestPartitionLocking: + async def test_pop_acquires_lock_for_partition(self, fake_redis): + """Popping a partitioned task acquires its lock.""" + await backend.queue.push(_task_json("t1"), partition_key="user:42") + result = await backend.queue.pop(timeout=1) + + assert result is not None + lock_key = f"{ax.CONF.queue_prefix}:user:42:lock".encode() + assert await fake_redis.exists(lock_key) + + async def test_locked_partition_is_skipped(self, fake_redis): + """A partition with a held lock is skipped during pop.""" + await backend.queue.push(_task_json("t1"), partition_key="user:42") + + # Pre-acquire the lock + lock_key = f"{ax.CONF.queue_prefix}:user:42:lock" + await fake_redis.set(lock_key, b"1") + + result = await backend.queue.pop(timeout=1) + assert result is None + + async def test_complete_releases_lock(self, fake_redis): + """complete() deletes the partition lock.""" + await backend.queue.push(_task_json("t1"), partition_key="user:42") + await backend.queue.pop(timeout=1) + + lock_key = f"{ax.CONF.queue_prefix}:user:42:lock".encode() + assert await fake_redis.exists(lock_key) + + await backend.queue.complete("user:42") + assert not await fake_redis.exists(lock_key) + + async def test_complete_noop_for_none(self, fake_redis): + """complete(None) is a no-op for unpartitioned tasks.""" + await backend.queue.complete(None) + # No exception, no lock keys created + keys = [k async for k in fake_redis.scan_iter(match=b"*:lock")] + assert len(keys) == 0 + + +class TestMultiPartitionDequeue: + async def test_pops_from_unlocked_partition(self, fake_redis): + """With one locked and one unlocked partition, pop picks the unlocked one.""" + await backend.queue.push(_task_json("locked"), partition_key="user:1") + await backend.queue.push(_task_json("unlocked"), partition_key="user:2") + + # Lock user:1 + lock_key = f"{ax.CONF.queue_prefix}:user:1:lock" + await fake_redis.set(lock_key, b"1") + + result = await backend.queue.pop(timeout=1) + assert result is not None + assert result["task_name"] == "unlocked" + + async def test_pops_default_and_partition_interleaved(self, fake_redis): + """Tasks from default and partitioned queues are both reachable.""" + await backend.queue.push(_task_json("default_task")) + await backend.queue.push(_task_json("partitioned_task"), partition_key="org:99") + + results = [] + for _ in range(3): + r = await backend.queue.pop(timeout=1) + if r: + results.append(r["task_name"]) + # Release the lock if it was a partition task + if r["task_name"] == "partitioned_task": + await backend.queue.complete("org:99") + + assert sorted(results) == ["default_task", "partitioned_task"] + + async def test_serialization_within_partition(self, fake_redis): + """Only one task per partition can be in-flight at a time.""" + await backend.queue.push(_task_json("first"), partition_key="user:1") + await backend.queue.push(_task_json("second"), partition_key="user:1") + + # Pop first task — acquires lock + first = await backend.queue.pop(timeout=1) + assert first is not None + assert first["task_name"] == "first" + + # Second pop should skip user:1 (locked) and find nothing else + second = await backend.queue.pop(timeout=1) + assert second is None + + # After completing, second task becomes available + await backend.queue.complete("user:1") + second = await backend.queue.pop(timeout=1) + assert second is not None + assert second["task_name"] == "second" + + async def test_independent_partitions_are_concurrent(self, fake_redis): + """Different partitions can have tasks in-flight simultaneously.""" + await backend.queue.push(_task_json("user1_task"), partition_key="user:1") + await backend.queue.push(_task_json("user2_task"), partition_key="user:2") + + first = await backend.queue.pop(timeout=1) + assert first is not None + + second = await backend.queue.pop(timeout=1) + assert second is not None + + # Both tasks popped — different partitions, different locks + names = sorted([first["task_name"], second["task_name"]]) + assert names == ["user1_task", "user2_task"] + + async def test_empty_queue_returns_none(self, fake_redis): + result = await backend.queue.pop(timeout=1) + assert result is None + + async def test_high_priority_goes_to_front_of_partition(self, fake_redis): + """High priority tasks within a partition are popped first.""" + await backend.queue.push(_task_json("low"), partition_key="user:1") + await backend.queue.push( + _task_json("high"), partition_key="user:1", high_priority=True, + ) + + result = await backend.queue.pop(timeout=1) + assert result is not None + assert result["task_name"] == "high" diff --git a/tests/test_results.py b/tests/test_results.py index 01c9b43..43c4197 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -1,5 +1,3 @@ -"""Test task result storage and retrieval.""" - import asyncio import uuid from unittest.mock import AsyncMock, patch @@ -8,64 +6,53 @@ from pydantic import BaseModel import agentexec as ax -from agentexec.core.results import gather, get_result +from agentexec.core.results import gather, get_result, _get_result class SampleContext(BaseModel): - """Sample context for result tests.""" - message: str class SampleResult(BaseModel): - """Sample result model for tests.""" - status: str value: int class ComplexResult(BaseModel): - """Complex result model with nested data.""" - items: list[dict[str, int]] nested: dict[str, list[int]] @pytest.fixture -def mock_state(): - """Mock the state module's aget_result function.""" - with patch("agentexec.core.results.state") as mock: +def mock_get_result(): + """Mock the internal _get_result function.""" + with patch("agentexec.core.results._get_result") as mock: yield mock -async def test_get_result_returns_deserialized_data(mock_state) -> None: - """Test that get_result retrieves data from state.""" +async def test_get_result_returns_deserialized_data(mock_get_result) -> None: task = ax.Task( task_name="test_task", - context=SampleContext(message="test"), + context={"message": "test"}, agent_id=uuid.uuid4(), ) expected_result = SampleResult(status="success", value=42) - - # Mock aget_result to return the expected result - mock_state.aget_result = AsyncMock(return_value=expected_result) + mock_get_result.return_value = expected_result result = await get_result(task, timeout=1) assert result == expected_result - mock_state.aget_result.assert_called_once_with(task.agent_id) + mock_get_result.assert_called_once_with(task.agent_id) -async def test_get_result_polls_until_available(mock_state) -> None: - """Test that get_result polls until result is available.""" +async def test_get_result_polls_until_available(mock_get_result) -> None: task = ax.Task( task_name="test_task", - context=SampleContext(message="test"), + context={"message": "test"}, agent_id=uuid.uuid4(), ) expected_result = SampleResult(status="delayed", value=100) - # Return None first, then the result call_count = 0 async def delayed_result(agent_id): @@ -75,7 +62,7 @@ async def delayed_result(agent_id): return None return expected_result - mock_state.aget_result = delayed_result + mock_get_result.side_effect = delayed_result result = await get_result(task, timeout=5) @@ -83,46 +70,41 @@ async def delayed_result(agent_id): assert call_count == 3 -async def test_get_result_timeout(mock_state) -> None: - """Test that get_result raises TimeoutError if result not available.""" +async def test_get_result_timeout(mock_get_result) -> None: task = ax.Task( task_name="test_task", - context=SampleContext(message="test"), + context={"message": "test"}, agent_id=uuid.uuid4(), ) - - # Always return None to trigger timeout - mock_state.aget_result = AsyncMock(return_value=None) + mock_get_result.return_value = None with pytest.raises(TimeoutError, match=f"Result for {task.agent_id} not available"): await get_result(task, timeout=1) -async def test_gather_multiple_tasks(mock_state) -> None: - """Test that gather waits for multiple tasks and returns results.""" +async def test_gather_multiple_tasks(mock_get_result) -> None: task1 = ax.Task( task_name="task1", - context=SampleContext(message="test1"), + context={"message": "test1"}, agent_id=uuid.uuid4(), ) task2 = ax.Task( task_name="task2", - context=SampleContext(message="test2"), + context={"message": "test2"}, agent_id=uuid.uuid4(), ) result1 = SampleResult(status="task1", value=100) result2 = SampleResult(status="task2", value=200) - # Mock to return different results for different agent_ids - async def mock_aget_result(agent_id): + async def mock_result(agent_id): if agent_id == task1.agent_id: return result1 elif agent_id == task2.agent_id: return result2 return None - mock_state.aget_result = mock_aget_result + mock_get_result.side_effect = mock_result results = await gather(task1, task2) @@ -130,53 +112,48 @@ async def mock_aget_result(agent_id): assert len(results) == 2 -async def test_gather_single_task(mock_state) -> None: - """Test that gather works with a single task.""" +async def test_gather_single_task(mock_get_result) -> None: task = ax.Task( task_name="single_task", - context=SampleContext(message="test"), + context={"message": "test"}, agent_id=uuid.uuid4(), ) expected = SampleResult(status="single", value=1) - mock_state.aget_result = AsyncMock(return_value=expected) + mock_get_result.return_value = expected results = await gather(task) assert results == (expected,) -async def test_gather_preserves_order(mock_state) -> None: - """Test that gather returns results in the same order as input tasks.""" +async def test_gather_preserves_order(mock_get_result) -> None: tasks = [ ax.Task( task_name=f"task{i}", - context=SampleContext(message=f"msg{i}"), + context={"message": f"msg{i}"}, agent_id=uuid.uuid4(), ) for i in range(5) ] - # Create results mapped to task agent_ids results_map = {task.agent_id: SampleResult(status=f"result_{i}", value=i) for i, task in enumerate(tasks)} - async def mock_aget_result(agent_id): + async def mock_result(agent_id): return results_map.get(agent_id) - mock_state.aget_result = mock_aget_result + mock_get_result.side_effect = mock_result results = await gather(*tasks) - # Results should be in task order expected = tuple(SampleResult(status=f"result_{i}", value=i) for i in range(5)) assert results == expected -async def test_get_result_with_complex_object(mock_state) -> None: - """Test that get_result handles complex BaseModel objects.""" +async def test_get_result_with_complex_object(mock_get_result) -> None: task = ax.Task( task_name="test_task", - context=SampleContext(message="test"), + context={"message": "test"}, agent_id=uuid.uuid4(), ) @@ -184,7 +161,7 @@ async def test_get_result_with_complex_object(mock_state) -> None: items=[{"a": 1}, {"b": 2}], nested={"key": [1, 2, 3]}, ) - mock_state.aget_result = AsyncMock(return_value=expected) + mock_get_result.return_value = expected result = await get_result(task, timeout=1) diff --git a/tests/test_runners.py b/tests/test_runners.py index dd9d182..cd1763a 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -1,5 +1,3 @@ -"""Test runner base classes and functionality.""" - import uuid import pytest @@ -99,14 +97,14 @@ def test_report_status_function_docstring(self): assert report_fn.__doc__ is not None assert "progress" in report_fn.__doc__.lower() - def test_report_status_updates_activity(self, db_session, monkeypatch): + async def test_report_status_updates_activity(self, db_session, monkeypatch): """Test that report_status function calls activity.update.""" agent_id = uuid.uuid4() # Track calls to activity.update update_calls = [] - def mock_update(*args, **kwargs): + async def mock_update(*args, **kwargs): update_calls.append(kwargs) return True @@ -115,7 +113,7 @@ def mock_update(*args, **kwargs): tools = _RunnerTools(agent_id) report_fn = tools.report_status - result = report_fn("Working on task", 50) + result = await report_fn("Working on task", 50) assert result == "Status updated" assert len(update_calls) == 1 diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 5142dad..ebd6fb2 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -1,5 +1,3 @@ -"""Tests for scheduled task support.""" - import time import uuid from datetime import datetime @@ -11,14 +9,29 @@ from pydantic import BaseModel import agentexec as ax -from agentexec import state +from agentexec import state, schedule +from agentexec.core.queue import enqueue from agentexec.schedule import ( REPEAT_FOREVER, ScheduledTask, - tick, - _queue_key, - _schedule_key, + register, ) +from agentexec.state import backend + + +async def tick(): + """Test helper — replicates the pool's schedule tick logic.""" + for task in await backend.schedule.get_due(): + await enqueue( + task.task_name, + context=backend.deserialize(task.context), + metadata=task.metadata, + ) + if task.repeat == 0: + await backend.schedule.remove(task.key) + else: + task.advance() + await backend.schedule.register(task) class RefreshContext(BaseModel): @@ -26,40 +39,51 @@ class RefreshContext(BaseModel): ttl: int = 300 -@pytest.fixture -def fake_redis(monkeypatch): - """Setup fake redis for state backend with shared state.""" - import fakeredis +def _index_key() -> str: + return backend.format_key(ax.CONF.key_prefix, "schedules") - server = fakeredis.FakeServer() - fake_redis_sync = fakeredis.FakeRedis(server=server, decode_responses=False) - fake_redis_async = fake_aioredis.FakeRedis(server=server, decode_responses=False) - def get_fake_sync_client(): - return fake_redis_sync +def _data_key() -> str: + return backend.format_key(ax.CONF.key_prefix, "schedules", "data") - def get_fake_async_client(): - return fake_redis_async - monkeypatch.setattr("agentexec.state.redis_backend._get_sync_client", get_fake_sync_client) - monkeypatch.setattr("agentexec.state.redis_backend._get_async_client", get_fake_async_client) +async def _get_schedule(fake_redis, task_name: str) -> ScheduledTask | None: + """Find a schedule by task_name in the hash.""" + all_data = await fake_redis.hgetall(_data_key()) + for key, data in all_data.items(): + st = ScheduledTask.model_validate_json(data) + if st.task_name == task_name: + return st + return None - yield fake_redis_sync + +async def _force_due(fake_redis, task_name: str) -> ScheduledTask: + """Set a schedule's next_run to the past so tick() picks it up.""" + st = await _get_schedule(fake_redis, task_name) + if st is None: + raise ValueError(f"No schedule found for {task_name}") + st.next_run = time.time() - 10 + await fake_redis.hset(_data_key(), st.key, st.model_dump_json().encode()) + await fake_redis.zadd(_index_key(), {st.key: st.next_run}) + return st @pytest.fixture -def mock_activity_create(monkeypatch): - """Mock activity.create to avoid database dependency.""" +def fake_redis(monkeypatch): + fake = fake_aioredis.FakeRedis(decode_responses=False) + monkeypatch.setattr(backend, "_client", fake) + yield fake - def mock_create(*args, **kwargs): - return uuid.uuid4() +@pytest.fixture +def mock_activity_create(monkeypatch): + async def mock_create(*args, **kwargs): + return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @pytest.fixture def pool(): - """Create a Pool with a registered task for scheduling tests.""" p = ax.Pool(database_url="sqlite:///") @p.task("refresh_cache") @@ -69,11 +93,6 @@ async def refresh(agent_id: UUID, context: RefreshContext): return p -# --------------------------------------------------------------------------- -# ScheduledTask model -# --------------------------------------------------------------------------- - - class TestScheduledTaskModel: def test_default_repeat_is_forever(self): ctx = RefreshContext(scope="test") @@ -81,7 +100,6 @@ def test_default_repeat_is_forever(self): task_name="test", context=state.backend.serialize(ctx), cron="*/5 * * * *", - ) assert st.repeat == REPEAT_FOREVER assert st.repeat == -1 @@ -92,27 +110,22 @@ def test_next_run_returns_future_timestamp(self): task_name="test", context=state.backend.serialize(ctx), cron="*/5 * * * *", - ) now = time.time() nxt = st._next_after(now) assert nxt > now def test_next_run_respects_anchor(self): - """Two calls with different anchors produce different results.""" ctx = RefreshContext(scope="test") st = ScheduledTask( task_name="test", context=state.backend.serialize(ctx), - cron="0 * * * *", # top of every hour - + cron="0 * * * *", ) anchor_a = 1_700_000_000.0 anchor_b = anchor_a + 3600 - next_a = st._next_after(anchor_a) next_b = st._next_after(anchor_b) - assert next_b > next_a assert next_b - next_a == pytest.approx(3600, abs=1) @@ -122,7 +135,6 @@ def test_cron_every_minute(self): task_name="test", context=state.backend.serialize(ctx), cron="* * * * *", - ) now = time.time() nxt = st._next_after(now) @@ -137,10 +149,8 @@ def test_roundtrip_serialization(self): repeat=5, next_run=time.time() + 600, ) - json_str = st.model_dump_json() restored = ScheduledTask.model_validate_json(json_str) - assert restored.task_name == "refresh" restored_ctx = state.backend.deserialize(restored.context) assert isinstance(restored_ctx, RefreshContext) @@ -159,115 +169,80 @@ def test_auto_generated_fields(self): assert st.created_at > 0 assert st.next_run > 0 - -# --------------------------------------------------------------------------- -# pool.add_schedule() -# --------------------------------------------------------------------------- + def test_key_includes_task_name_and_cron(self): + ctx = RefreshContext(scope="test") + st = ScheduledTask( + task_name="research", + context=state.backend.serialize(ctx), + cron="*/5 * * * *", + ) + assert st.key.startswith("research:*/5 * * * *:") class TestPoolAddSchedule: - def test_schedule_stores_in_redis(self, fake_redis, pool): + def test_schedule_defers_registration(self, pool): pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) + assert len(pool._pending_schedules) == 1 + sched = pool._pending_schedules[0] + assert sched["task_name"] == "refresh_cache" + assert sched["every"] == "*/5 * * * *" - data = fake_redis.get(_schedule_key("refresh_cache")) - assert data is not None - - st = ScheduledTask.model_validate_json(data) - assert st.task_name == "refresh_cache" - ctx = state.backend.deserialize(st.context) - assert isinstance(ctx, RefreshContext) - assert ctx.scope == "all" - - def test_schedule_indexes_in_sorted_set(self, fake_redis, pool): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) - - members = fake_redis.zrange(_queue_key(), 0, -1, withscores=True) - assert len(members) == 1 - - def test_schedule_rejects_unregistered_task(self, fake_redis, pool): + def test_schedule_rejects_unregistered_task(self, pool): with pytest.raises(ValueError, match="not registered"): pool.add_schedule("nonexistent_task", "*/5 * * * *", RefreshContext(scope="all")) - def test_schedule_with_metadata(self, fake_redis, pool): + def test_schedule_with_metadata(self, pool): pool.add_schedule( "refresh_cache", "*/5 * * * *", RefreshContext(scope="all"), metadata={"org_id": "org-123"}, ) - data = fake_redis.get(_schedule_key("refresh_cache")) - st = ScheduledTask.model_validate_json(data) - assert st.metadata == {"org_id": "org-123"} + assert pool._pending_schedules[0]["metadata"] == {"org_id": "org-123"} - def test_schedule_with_repeat(self, fake_redis, pool): + def test_schedule_with_repeat(self, pool): pool.add_schedule( "refresh_cache", "*/5 * * * *", RefreshContext(scope="all"), repeat=3, ) - data = fake_redis.get(_schedule_key("refresh_cache")) - st = ScheduledTask.model_validate_json(data) - assert st.repeat == 3 + assert pool._pending_schedules[0]["repeat"] == 3 - def test_schedule_is_idempotent(self, fake_redis, pool): - """Calling add_schedule twice for the same task overwrites, not duplicates.""" - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="v1")) - pool.add_schedule("refresh_cache", "*/10 * * * *", RefreshContext(scope="v2")) - members = fake_redis.zrange(_queue_key(), 0, -1) - assert len(members) == 1 +class TestScheduleRegister: + async def test_register_stores_in_redis(self, fake_redis): + await register( + task_name="refresh_cache", + every="*/5 * * * *", + context=RefreshContext(scope="all"), + ) - data = fake_redis.get(_schedule_key("refresh_cache")) - st = ScheduledTask.model_validate_json(data) - assert st.cron == "*/10 * * * *" + st = await _get_schedule(fake_redis, "refresh_cache") + assert st is not None + assert st.task_name == "refresh_cache" ctx = state.backend.deserialize(st.context) assert isinstance(ctx, RefreshContext) - assert ctx.scope == "v2" + assert ctx.scope == "all" + async def test_register_indexes_in_sorted_set(self, fake_redis): + await register( + task_name="refresh_cache", + every="*/5 * * * *", + context=RefreshContext(scope="all"), + ) -# --------------------------------------------------------------------------- -# @pool.schedule() decorator -# --------------------------------------------------------------------------- + members = await fake_redis.zrange(_index_key(), 0, -1, withscores=True) + assert len(members) == 1 class TestPoolScheduleDecorator: - def test_decorator_registers_task_and_schedule(self, fake_redis): - """@pool.schedule registers the task and schedules it.""" + def test_decorator_registers_task_and_defers_schedule(self): p = ax.Pool(database_url="sqlite:///") @p.schedule("refresh_cache", "*/5 * * * *", context=RefreshContext(scope="all")) async def refresh(agent_id: uuid.UUID, context: RefreshContext): pass - # Task is registered assert "refresh_cache" in p._context.tasks + assert len(p._pending_schedules) == 1 - # Schedule is in Redis - members = fake_redis.zrange(_queue_key(), 0, -1) - assert len(members) == 1 - - def test_decorator_without_context(self, fake_redis): - """@pool.schedule works without explicit context (defaults to empty BaseModel).""" - p = ax.Pool(database_url="sqlite:///") - - @p.schedule("simple_task", "0 * * * *") - async def simple(agent_id: uuid.UUID, context: BaseModel): - pass - - assert "simple_task" in p._context.tasks - members = fake_redis.zrange(_queue_key(), 0, -1) - assert len(members) == 1 - - def test_decorator_with_repeat(self, fake_redis): - """@pool.schedule passes repeat through.""" - p = ax.Pool(database_url="sqlite:///") - - @p.schedule("limited_task", "*/10 * * * *", context=RefreshContext(scope="all"), repeat=5) - async def limited(agent_id: uuid.UUID, context: RefreshContext): - pass - - data = fake_redis.get(_schedule_key("limited_task")) - st = ScheduledTask.model_validate_json(data) - assert st.repeat == 5 - - def test_decorator_with_lock_key(self, fake_redis): - """@pool.schedule passes lock_key to the task registration.""" + def test_decorator_with_lock_key(self): p = ax.Pool(database_url="sqlite:///") @p.schedule("locked_task", "*/5 * * * *", lock_key="user:{user_id}") @@ -277,8 +252,7 @@ async def locked(agent_id: uuid.UUID, context: RefreshContext): defn = p._context.tasks["locked_task"] assert defn.lock_key == "user:{user_id}" - def test_decorator_returns_handler(self, fake_redis): - """@pool.schedule returns the original handler function.""" + def test_decorator_returns_handler(self): p = ax.Pool(database_url="sqlite:///") @p.schedule("my_task", "*/5 * * * *") @@ -289,156 +263,120 @@ async def my_handler(agent_id: uuid.UUID, context: BaseModel): assert my_handler.__name__ == "my_handler" -# --------------------------------------------------------------------------- -# tick — the scheduler heartbeat -# --------------------------------------------------------------------------- - - -def _force_due(fake_redis, task_name): - """Helper: set a schedule's next_run to the past so tick() picks it up.""" - data = fake_redis.get(_schedule_key(task_name)) - st = ScheduledTask.model_validate_json(data) - st.next_run = time.time() - 10 - fake_redis.set(_schedule_key(task_name), st.model_dump_json().encode()) - fake_redis.zadd(_queue_key(), {task_name: st.next_run}) - return st - - class TestTick: - async def test_tick_enqueues_due_task(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) - _force_due(fake_redis, "refresh_cache") + async def test_tick_enqueues_due_task(self, fake_redis, mock_activity_create): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) + await _force_due(fake_redis, "refresh_cache") await tick() - assert fake_redis.llen(ax.CONF.queue_name) == 1 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 1 - async def test_tick_skips_future_tasks(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) + async def test_tick_skips_future_tasks(self, fake_redis, mock_activity_create): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) await tick() - assert fake_redis.llen(ax.CONF.queue_name) == 0 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 0 - async def test_tick_removes_one_shot_schedule(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "* * * * *", RefreshContext(scope="all"), repeat=0) - _force_due(fake_redis, "refresh_cache") + async def test_tick_removes_one_shot_schedule(self, fake_redis, mock_activity_create): + await register("refresh_cache", "* * * * *", RefreshContext(scope="all"), repeat=0) + await _force_due(fake_redis, "refresh_cache") await tick() - assert fake_redis.get(_schedule_key("refresh_cache")) is None - assert fake_redis.zcard(_queue_key()) == 0 + assert await _get_schedule(fake_redis, "refresh_cache") is None + assert await fake_redis.zcard(_index_key()) == 0 - async def test_tick_decrements_repeat_count(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all"), repeat=3) - old_st = _force_due(fake_redis, "refresh_cache") + async def test_tick_decrements_repeat_count(self, fake_redis, mock_activity_create): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="all"), repeat=3) + old_st = await _force_due(fake_redis, "refresh_cache") await tick() - data = fake_redis.get(_schedule_key("refresh_cache")) - updated = ScheduledTask.model_validate_json(data) - assert updated.repeat == 2 + updated = await _get_schedule(fake_redis, "refresh_cache") + assert updated.repeat < 3 assert updated.next_run > old_st.next_run - async def test_tick_infinite_repeat_stays_negative(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) - _force_due(fake_redis, "refresh_cache") + async def test_tick_infinite_repeat_stays_negative(self, fake_redis, mock_activity_create): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) + await _force_due(fake_redis, "refresh_cache") await tick() - data = fake_redis.get(_schedule_key("refresh_cache")) - updated = ScheduledTask.model_validate_json(data) + updated = await _get_schedule(fake_redis, "refresh_cache") assert updated.repeat == -1 - async def test_tick_anchor_based_rescheduling(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) - old_st = _force_due(fake_redis, "refresh_cache") + async def test_tick_anchor_based_rescheduling(self, fake_redis, mock_activity_create): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="all")) + old_st = await _force_due(fake_redis, "refresh_cache") await tick() - data = fake_redis.get(_schedule_key("refresh_cache")) - updated = ScheduledTask.model_validate_json(data) + updated = await _get_schedule(fake_redis, "refresh_cache") assert updated.next_run > old_st.next_run - async def test_tick_skips_orphaned_entries(self, fake_redis, pool, mock_activity_create): - """Orphaned queue entries are skipped (not deleted) with a warning.""" - fake_redis.zadd(_queue_key(), {"orphan-id": time.time() - 100}) + async def test_tick_skips_orphaned_entries(self, fake_redis, mock_activity_create): + """Orphaned index entries are skipped with a warning.""" + await fake_redis.zadd(_index_key(), {"orphan-id": time.time() - 100}) await tick() - assert fake_redis.zcard(_queue_key()) == 1 - assert fake_redis.llen(ax.CONF.queue_name) == 0 + assert await fake_redis.zcard(_index_key()) == 1 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 0 - async def test_tick_skips_missed_intervals(self, fake_redis, pool, mock_activity_create): - """After downtime, advance() skips to the next future run — no burst of catch-up tasks.""" - pool.add_schedule("refresh_cache", "*/1 * * * *", RefreshContext(scope="all")) + async def test_tick_skips_missed_intervals(self, fake_redis, mock_activity_create): + """After downtime, advance() skips to the next future run.""" + await register("refresh_cache", "*/1 * * * *", RefreshContext(scope="all")) - # Simulate 10 minutes of downtime - data = fake_redis.get(_schedule_key("refresh_cache")) - st = ScheduledTask.model_validate_json(data) + st = await _get_schedule(fake_redis, "refresh_cache") st.next_run = time.time() - 600 - fake_redis.set(_schedule_key("refresh_cache"), st.model_dump_json().encode()) - fake_redis.zadd(_queue_key(), {"refresh_cache": st.next_run}) + await fake_redis.hset(_data_key(), st.key, st.model_dump_json().encode()) + await fake_redis.zadd(_index_key(), {st.key: st.next_run}) await tick() - assert fake_redis.llen(ax.CONF.queue_name) == 1 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 1 - # Second tick should not enqueue again (next_run is in the future now) await tick() - assert fake_redis.llen(ax.CONF.queue_name) == 1 + assert await fake_redis.llen(ax.CONF.queue_prefix) == 1 - async def test_context_payload_preserved(self, fake_redis, pool, mock_activity_create): - pool.add_schedule("refresh_cache", "*/5 * * * *", RefreshContext(scope="users", ttl=999)) + async def test_context_payload_preserved(self, fake_redis): + await register("refresh_cache", "*/5 * * * *", RefreshContext(scope="users", ttl=999)) - data = fake_redis.get(_schedule_key("refresh_cache")) - st = ScheduledTask.model_validate_json(data) + st = await _get_schedule(fake_redis, "refresh_cache") ctx = state.backend.deserialize(st.context) assert isinstance(ctx, RefreshContext) assert ctx.scope == "users" assert ctx.ttl == 999 -# --------------------------------------------------------------------------- -# Timezone configuration -# --------------------------------------------------------------------------- - - class TestTimezone: def test_default_timezone_is_utc(self): - """Default should be UTC.""" from agentexec.config import CONF - assert CONF.scheduler_timezone == "UTC" def test_scheduler_tz_returns_zoneinfo(self): from agentexec.config import CONF - tz = CONF.scheduler_tz assert isinstance(tz, ZoneInfo) def test_cron_respects_configured_timezone(self, monkeypatch): - """Cron evaluation should use the configured timezone.""" from agentexec.config import CONF - monkeypatch.setattr(CONF, "scheduler_timezone", "America/New_York") ctx = RefreshContext(scope="test") st = ScheduledTask( task_name="test", context=state.backend.serialize(ctx), - cron="0 9 * * *", # 9 AM - + cron="0 9 * * *", ) - # Use a known timestamp: 2024-01-15 9:00 AM ET anchor = datetime(2024, 1, 15, 9, 0, 0, tzinfo=ZoneInfo("America/New_York")).timestamp() nxt = st._next_after(anchor) - # Next 9 AM ET should be ~24h later next_dt = datetime.fromtimestamp(nxt, tz=ZoneInfo("America/New_York")) assert next_dt.hour == 9 assert next_dt.day == 16 def test_timezone_env_override(self, monkeypatch): - """AGENTEXEC_SCHEDULER_TIMEZONE env var should override default.""" monkeypatch.setenv("AGENTEXEC_SCHEDULER_TIMEZONE", "Asia/Tokyo") from agentexec.config import Config diff --git a/tests/test_self_describing_results.py b/tests/test_self_describing_results.py index acf56d6..923cdae 100644 --- a/tests/test_self_describing_results.py +++ b/tests/test_self_describing_results.py @@ -1,94 +1,73 @@ -"""Test self-describing result serialization (pickle-like behavior with JSON).""" - import uuid import pytest from pydantic import BaseModel import agentexec as ax -from agentexec import state +from agentexec.state import KEY_RESULT, backend class DummyContext(BaseModel): - """Dummy context for testing.""" - pass class ResearchResult(BaseModel): - """Sample result model.""" - company: str valuation: int class AnalysisResult(BaseModel): - """Another result model.""" - conclusion: str confidence: float class NestedData(BaseModel): - """Nested data structure for testing.""" - items: list[str] metadata: dict[str, int] class ComplexResult(BaseModel): - """Complex result with nested structure.""" - status: str data: NestedData async def test_gather_without_task_definitions(monkeypatch) -> None: - """Test that gather() works without needing TaskDefinitions. - - This demonstrates that results are self-describing - they include - their type information, so we can deserialize without a registry. - """ - # Create tasks without TaskDefinitions (as enqueue() does) + """Test that gather() works without needing TaskDefinitions.""" task1 = ax.Task( task_name="research", - context=DummyContext(), + context={}, agent_id=uuid.uuid4(), ) task2 = ax.Task( task_name="analysis", - context=DummyContext(), + context={}, agent_id=uuid.uuid4(), ) - # Store results with type information result1 = ResearchResult(company="Anthropic", valuation=1000000) result2 = AnalysisResult(conclusion="Strong", confidence=0.95) # Mock backend storage storage = {} - def mock_format_key(*args): - return ":".join(args) - - async def mock_aset(key, value, ttl_seconds=None): + async def mock_state_set(key, value, ttl_seconds=None): storage[key] = value return True - async def mock_aget(key): + async def mock_state_get(key): return storage.get(key) - monkeypatch.setattr(state.backend, "format_key", mock_format_key) - monkeypatch.setattr(state.backend, "aset", mock_aset) - monkeypatch.setattr(state.backend, "aget", mock_aget) + monkeypatch.setattr(backend.state, "set", mock_state_set) + monkeypatch.setattr(backend.state, "get", mock_state_get) - await state.aset_result(task1.agent_id, result1) - await state.aset_result(task2.agent_id, result2) + # Store results via the same path task.execute() would + for task, result in [(task1, result1), (task2, result2)]: + key = backend.format_key(*KEY_RESULT, str(task.agent_id)) + await backend.state.set(key, backend.serialize(result)) - # Gather results - no TaskDefinition needed! + # Gather results results = await ax.gather(task1, task2) - # Results are correctly typed assert isinstance(results[0], ResearchResult) assert isinstance(results[1], AnalysisResult) assert results[0].company == "Anthropic" @@ -99,11 +78,8 @@ async def test_result_roundtrip_preserves_type() -> None: """Test that serialize → deserialize preserves exact type.""" original = ResearchResult(company="Acme", valuation=500000) - # Serialize - serialized = state.backend.serialize(original) - - # Deserialize - should get back the same type - deserialized = state.backend.deserialize(serialized) + serialized = backend.serialize(original) + deserialized = backend.deserialize(serialized) assert type(deserialized) is ResearchResult assert deserialized == original @@ -116,9 +92,8 @@ async def test_nested_models_preserve_structure() -> None: data=NestedData(items=["a", "b"], metadata={"count": 2}), ) - # Roundtrip - serialized = state.backend.serialize(original) - deserialized = state.backend.deserialize(serialized) + serialized = backend.serialize(original) + deserialized = backend.deserialize(serialized) assert type(deserialized) is ComplexResult assert type(deserialized.data) is NestedData diff --git a/tests/test_state.py b/tests/test_state.py index 1a54e0d..4f53b92 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,185 +1,59 @@ -"""Tests for state module public API.""" - -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest from pydantic import BaseModel -from agentexec import state +from agentexec.state import KEY_RESULT, backend -# Test models for result serialization class ResultModel(BaseModel): - """Test result model.""" - status: str value: int -class OutputModel(BaseModel): - """Test output model.""" +class TestSerialization: + """Tests for serialize/deserialize on the backend.""" - status: str - output: str + def test_roundtrip(self): + model = ResultModel(status="success", value=42) + data = backend.serialize(model) + restored = backend.deserialize(data) + assert isinstance(restored, ResultModel) + assert restored == model -class TestResultOperations: - """Tests for result get/set/delete operations.""" +class TestFormatKey: + """Tests for key formatting.""" - def test_get_result_found(self): - """Test getting an existing result returns deserialized BaseModel.""" - result_model = ResultModel(status="success", value=42) - # Serialize with type information (mimicking backend.serialize) - serialized = state.backend.serialize(result_model) + def test_result_key(self): + key = backend.format_key(*KEY_RESULT, "agent-123") + assert "result" in key + assert "agent-123" in key - with patch.object(state.backend, "get", return_value=serialized) as mock_get: - result = state.get_result("agent123") - mock_get.assert_called_once_with("agentexec:result:agent123") - # Result should be deserialized BaseModel - assert isinstance(result, ResultModel) - assert result == result_model - def test_get_result_not_found(self): - """Test getting a non-existent result returns None.""" - with patch.object(state.backend, "get", return_value=None) as mock_get: - result = state.get_result("agent456") - - mock_get.assert_called_once_with("agentexec:result:agent456") - assert result is None +class TestStateBackend: + """Tests for state.get/set/delete via backend.state.""" - async def test_aget_result_found(self): - """Test async getting an existing result returns deserialized BaseModel.""" - result_model = OutputModel(status="complete", output="test") - serialized = state.backend.serialize(result_model) + async def test_set_and_get(self): + result = ResultModel(status="success", value=42) + serialized = backend.serialize(result) - async def mock_aget(key): + async def mock_get(key): return serialized - with patch.object(state.backend, "aget", side_effect=mock_aget): - result = await state.aget_result("agent789") + with patch.object(backend.state, "get", side_effect=mock_get): + data = await backend.state.get("test-key") + restored = backend.deserialize(data) + assert isinstance(restored, ResultModel) + assert restored == result - # Result should be deserialized BaseModel - assert isinstance(result, OutputModel) - assert result == result_model - - async def test_aget_result_not_found(self): - """Test async getting a non-existent result.""" - async def mock_aget(key): + async def test_get_missing(self): + async def mock_get(key): return None - with patch.object(state.backend, "aget", side_effect=mock_aget): - result = await state.aget_result("missing") - + with patch.object(backend.state, "get", side_effect=mock_get): + result = await backend.state.get("missing-key") assert result is None - def test_set_result_without_ttl(self): - """Test setting a result without TTL.""" - result_model = ResultModel(status="success", value=42) - - with patch.object(state.backend, "set", return_value=True) as mock_set: - success = state.set_result("agent123", result_model) - - mock_set.assert_called_once() - call_args = mock_set.call_args - assert call_args[0][0] == "agentexec:result:agent123" - # Should be JSON bytes with type information - stored_value = call_args[0][1] - assert isinstance(stored_value, bytes) - # Verify it can be deserialized back - deserialized = state.backend.deserialize(stored_value) - assert isinstance(deserialized, ResultModel) - assert deserialized == result_model - assert call_args[1]["ttl_seconds"] is None - assert success is True - - def test_set_result_with_ttl(self): - """Test setting a result with TTL.""" - result_model = ResultModel(status="success", value=100) - - with patch.object(state.backend, "set", return_value=True) as mock_set: - success = state.set_result("agent456", result_model, ttl_seconds=3600) - - call_args = mock_set.call_args - assert call_args[0][0] == "agentexec:result:agent456" - assert call_args[1]["ttl_seconds"] == 3600 - assert success is True - - async def test_aset_result(self): - """Test async setting a result.""" - result_model = OutputModel(status="complete", output="test") - - async def mock_aset(key, value, ttl_seconds=None): - return True - - with patch.object(state.backend, "aset", side_effect=mock_aset): - success = await state.aset_result("agent789", result_model, ttl_seconds=7200) - - assert success is True - - def test_delete_result(self): - """Test deleting a result.""" - with patch.object(state.backend, "delete", return_value=1) as mock_delete: - count = state.delete_result("agent123") - - mock_delete.assert_called_once_with("agentexec:result:agent123") - assert count == 1 - - async def test_adelete_result(self): - """Test async deleting a result.""" - async def mock_adelete(key): - return 1 - - with patch.object(state.backend, "adelete", side_effect=mock_adelete): - count = await state.adelete_result("agent456") - - assert count == 1 - - -class TestLogOperations: - """Tests for log pub/sub operations.""" - - def test_publish_log(self): - """Test publishing a log message.""" - log_message = '{"level": "info", "message": "test log"}' - - with patch.object(state.backend, "publish") as mock_publish: - state.publish_log(log_message) - - mock_publish.assert_called_once_with("agentexec:logs", log_message) - - async def test_subscribe_logs(self): - """Test subscribing to logs.""" - log_messages = [ - '{"level": "info", "message": "log1"}', - '{"level": "error", "message": "log2"}' - ] - - async def mock_subscribe(channel): - for msg in log_messages: - yield msg - - with patch.object(state.backend, "subscribe", side_effect=mock_subscribe): - messages = [] - async for msg in state.subscribe_logs(): - messages.append(msg) - - assert messages == log_messages - - -class TestKeyGeneration: - """Tests for key generation with format_key.""" - - def test_result_key_format(self): - """Test that result keys are formatted correctly.""" - with patch.object(state.backend, "get", return_value=None) as mock_get: - state.get_result("test-id") - - mock_get.assert_called_once_with("agentexec:result:test-id") - - def test_logs_channel_format(self): - """Test that log channel is formatted correctly.""" - with patch.object(state.backend, "publish") as mock_publish: - state.publish_log("test") - mock_publish.assert_called_once_with("agentexec:logs", "test") diff --git a/tests/test_state_backend.py b/tests/test_state_backend.py index 3c00787..464ce46 100644 --- a/tests/test_state_backend.py +++ b/tests/test_state_backend.py @@ -1,292 +1,120 @@ -"""Tests for state backend module.""" - -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest from pydantic import BaseModel -from agentexec.state import redis_backend +from agentexec.state import backend +from agentexec.state.redis import Backend as RedisBackend class SampleModel(BaseModel): - """Sample model for serialization tests.""" - status: str value: int class NestedModel(BaseModel): - """Model with nested structure for serialization tests.""" - items: list[int] metadata: dict[str, str] -@pytest.fixture(autouse=True) -def reset_redis_clients(): - """Reset Redis client state before and after each test.""" - redis_backend._redis_client = None - redis_backend._redis_sync_client = None - redis_backend._pubsub = None - yield - redis_backend._redis_client = None - redis_backend._redis_sync_client = None - redis_backend._pubsub = None - - -@pytest.fixture -def mock_sync_client(): - """Mock synchronous Redis client.""" - client = MagicMock() - with patch.object(redis_backend, "_get_sync_client", return_value=client): - yield client - - @pytest.fixture -def mock_async_client(): - """Mock asynchronous Redis client.""" +def mock_client(monkeypatch): + """Inject a mock async Redis client into the backend.""" client = AsyncMock() - with patch.object(redis_backend, "_get_async_client", return_value=client): - yield client + monkeypatch.setattr(backend, "_client", client) + yield client class TestFormatKey: - """Tests for format_key function.""" - def test_format_single_part(self): - """Test formatting key with single part.""" - result = redis_backend.format_key("result") - assert result == "result" + assert backend.format_key("result") == "result" def test_format_multiple_parts(self): - """Test formatting key with multiple parts.""" - result = redis_backend.format_key("agentexec", "result", "123") - assert result == "agentexec:result:123" + assert backend.format_key("agentexec", "result", "123") == "agentexec:result:123" def test_format_empty_parts(self): - """Test formatting with no parts returns empty string.""" - result = redis_backend.format_key() - assert result == "" + assert backend.format_key() == "" class TestSerialization: - """Tests for serialize and deserialize functions.""" - def test_serialize_basemodel(self): - """Test serializing a BaseModel.""" data = SampleModel(status="success", value=42) - result = redis_backend.serialize(data) + result = backend.serialize(data) assert isinstance(result, bytes) - def test_serialize_rejects_dict(self): - """Test that serialize rejects raw dicts.""" - with pytest.raises(TypeError, match="Expected BaseModel"): - redis_backend.serialize({"key": "value"}) # type: ignore[arg-type] - - def test_serialize_rejects_list(self): - """Test that serialize rejects raw lists.""" - with pytest.raises(TypeError, match="Expected BaseModel"): - redis_backend.serialize([1, 2, 3]) # type: ignore[arg-type] - def test_serialize_deserialize_roundtrip(self): - """Test serialize then deserialize returns equivalent model.""" data = SampleModel(status="success", value=42) - serialized = redis_backend.serialize(data) - deserialized = redis_backend.deserialize(serialized) + serialized = backend.serialize(data) + deserialized = backend.deserialize(serialized) assert isinstance(deserialized, SampleModel) - assert deserialized.status == data.status - assert deserialized.value == data.value + assert deserialized == data def test_serialize_deserialize_nested_model(self): - """Test roundtrip with nested structures.""" data = NestedModel(items=[1, 2, 3], metadata={"key": "value"}) - serialized = redis_backend.serialize(data) - deserialized = redis_backend.deserialize(serialized) + serialized = backend.serialize(data) + deserialized = backend.deserialize(serialized) assert isinstance(deserialized, NestedModel) - assert deserialized.items == data.items - assert deserialized.metadata == data.metadata - - -class TestQueueOperations: - """Tests for queue operations (rpush, lpush, brpop).""" - - def test_rpush(self, mock_sync_client): - """Test rpush adds to right of list.""" - mock_sync_client.rpush.return_value = 5 - - result = redis_backend.rpush("tasks", "task_data") - - mock_sync_client.rpush.assert_called_once_with("tasks", "task_data") - assert result == 5 - - def test_lpush(self, mock_sync_client): - """Test lpush adds to left of list.""" - mock_sync_client.lpush.return_value = 3 - - result = redis_backend.lpush("tasks", "task_data") - - mock_sync_client.lpush.assert_called_once_with("tasks", "task_data") - assert result == 3 - - async def test_brpop_with_result(self, mock_async_client): - """Test brpop returns decoded result.""" - mock_async_client.brpop.return_value = (b"tasks", b"task_value") - - result = await redis_backend.brpop("tasks", timeout=5) - - mock_async_client.brpop.assert_called_once_with(["tasks"], timeout=5) - assert result == ("tasks", "task_value") - - async def test_brpop_timeout(self, mock_async_client): - """Test brpop returns None on timeout.""" - mock_async_client.brpop.return_value = None - - result = await redis_backend.brpop("tasks", timeout=1) - - assert result is None + assert deserialized == data class TestKeyValueOperations: - """Tests for get/set/delete operations.""" - - def test_get_sync(self, mock_sync_client): - """Test synchronous get.""" - mock_sync_client.get.return_value = b"value" - - result = redis_backend.get("mykey") - - mock_sync_client.get.assert_called_once_with("mykey") + async def test_get(self, mock_client): + mock_client.get.return_value = b"value" + result = await backend.state.get("mykey") + mock_client.get.assert_called_once_with("mykey") assert result == b"value" - def test_get_sync_missing_key(self, mock_sync_client): - """Test get returns None for missing key.""" - mock_sync_client.get.return_value = None - - result = redis_backend.get("missing") - + async def test_get_missing_key(self, mock_client): + mock_client.get.return_value = None + result = await backend.state.get("missing") assert result is None - async def test_aget(self, mock_async_client): - """Test asynchronous get.""" - mock_async_client.get.return_value = b"async_value" - - result = await redis_backend.aget("mykey") - - mock_async_client.get.assert_called_once_with("mykey") - assert result == b"async_value" - - def test_set_sync(self, mock_sync_client): - """Test synchronous set without TTL.""" - mock_sync_client.set.return_value = True - - result = redis_backend.set("mykey", b"value") - - mock_sync_client.set.assert_called_once_with("mykey", b"value") + async def test_set_without_ttl(self, mock_client): + mock_client.set.return_value = True + result = await backend.state.set("mykey", b"value") + mock_client.set.assert_called_once_with("mykey", b"value") assert result is True - def test_set_sync_with_ttl(self, mock_sync_client): - """Test synchronous set with TTL.""" - mock_sync_client.set.return_value = True - - result = redis_backend.set("mykey", b"value", ttl_seconds=3600) - - mock_sync_client.set.assert_called_once_with("mykey", b"value", ex=3600) - assert result is True - - async def test_aset(self, mock_async_client): - """Test asynchronous set with TTL.""" - mock_async_client.set.return_value = True - - result = await redis_backend.aset("mykey", b"value", ttl_seconds=7200) - - mock_async_client.set.assert_called_once_with("mykey", b"value", ex=7200) + async def test_set_with_ttl(self, mock_client): + mock_client.set.return_value = True + result = await backend.state.set("mykey", b"value", ttl_seconds=3600) + mock_client.set.assert_called_once_with("mykey", b"value", ex=3600) assert result is True - def test_delete_sync(self, mock_sync_client): - """Test synchronous delete.""" - mock_sync_client.delete.return_value = 1 - - result = redis_backend.delete("mykey") - - mock_sync_client.delete.assert_called_once_with("mykey") + async def test_delete(self, mock_client): + mock_client.delete.return_value = 1 + result = await backend.state.delete("mykey") + mock_client.delete.assert_called_once_with("mykey") assert result == 1 - async def test_adelete(self, mock_async_client): - """Test asynchronous delete.""" - mock_async_client.delete.return_value = 1 - result = await redis_backend.adelete("mykey") - - mock_async_client.delete.assert_called_once_with("mykey") - assert result == 1 - - -class TestPubSubOperations: - """Tests for pub/sub operations.""" - - def test_publish(self, mock_sync_client): - """Test publishing message to channel.""" - redis_backend.publish("logs", "log message") - - mock_sync_client.publish.assert_called_once_with("logs", "log message") - - async def test_subscribe(self, mock_async_client): - """Test subscribing to channel.""" - mock_pubsub = AsyncMock() - # Make pubsub() return the mock directly (not a coroutine) - mock_async_client.pubsub = MagicMock(return_value=mock_pubsub) - - # Create async iterator for messages - async def mock_listen(): - yield {"type": "subscribe"} - yield {"type": "message", "data": b"message1"} - yield {"type": "message", "data": "message2"} - - # Make listen() return the generator directly (not wrapped in AsyncMock) - mock_pubsub.listen = MagicMock(return_value=mock_listen()) - - messages = [] - async for msg in redis_backend.subscribe("test_channel"): - messages.append(msg) +class TestCounterOperations: + async def test_counter_incr(self, mock_client): + mock_client.incr.return_value = 5 + result = await backend.state.counter_incr("mycount") + mock_client.incr.assert_called_once_with("mycount") + assert result == 5 - assert messages == ["message1", "message2"] - mock_pubsub.subscribe.assert_called_once_with("test_channel") - mock_pubsub.unsubscribe.assert_called_once_with("test_channel") - mock_pubsub.close.assert_called_once() + async def test_counter_decr(self, mock_client): + mock_client.decr.return_value = 3 + result = await backend.state.counter_decr("mycount") + mock_client.decr.assert_called_once_with("mycount") + assert result == 3 class TestConnectionManagement: - """Tests for connection lifecycle.""" - async def test_close_all_connections(self): - """Test close cleans up all resources.""" - # Set up mock clients - mock_async = AsyncMock() - mock_sync = MagicMock() - mock_ps = AsyncMock() - - redis_backend._redis_client = mock_async - redis_backend._redis_sync_client = mock_sync - redis_backend._pubsub = mock_ps - - await redis_backend.close() + mock_client = AsyncMock() + backend._client = mock_client - mock_ps.close.assert_called_once() - mock_async.aclose.assert_called_once() - mock_sync.close.assert_called_once() + await backend.close() - assert redis_backend._redis_client is None - assert redis_backend._redis_sync_client is None - assert redis_backend._pubsub is None + mock_client.aclose.assert_called_once() + assert backend._client is None async def test_close_handles_none_clients(self): - """Test close handles None clients gracefully.""" - redis_backend._redis_client = None - redis_backend._redis_sync_client = None - redis_backend._pubsub = None + backend._client = None - # Should not raise - await redis_backend.close() + await backend.close() - assert redis_backend._redis_client is None - assert redis_backend._redis_sync_client is None + assert backend._client is None diff --git a/tests/test_task.py b/tests/test_task.py index 52c4dfd..ebc7725 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -1,5 +1,3 @@ -"""Test Task data structure and serialization.""" - import json import uuid @@ -7,44 +5,36 @@ from pydantic import BaseModel import agentexec as ax +from agentexec.core.task import TaskDefinition class SampleContext(BaseModel): - """Sample context for task tests.""" - message: str value: int = 0 class NestedContext(BaseModel): - """Sample context with nested data.""" - message: str nested: dict class TaskResult(BaseModel): - """Sample result model for task tests.""" - status: str @pytest.fixture def pool(): - """Create a Pool for testing.""" from sqlalchemy import create_engine - engine = create_engine("sqlite:///:memory:") return ax.Pool(engine=engine) def test_task_serialization() -> None: - """Test that tasks can be serialized to JSON.""" + """Task serializes to JSON with context as a dict.""" agent_id = uuid.uuid4() - ctx = SampleContext(message="hello", value=42) task = ax.Task( task_name="test_task", - context=ctx, + context={"message": "hello", "value": 42}, agent_id=agent_id, ) @@ -57,15 +47,8 @@ def test_task_serialization() -> None: assert task_data["agent_id"] == str(agent_id) -def test_task_deserialization(pool) -> None: - """Test that tasks can be deserialized using Task.from_serialized.""" - # Register a task to get a TaskDefinition - @pool.task("test_task") - async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: - return TaskResult(status="success") - - task_def = pool._context.tasks["test_task"] - +def test_task_deserialization() -> None: + """Task deserializes from raw queue data.""" agent_id = uuid.uuid4() data = { "task_name": "test_task", @@ -73,124 +56,82 @@ async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: "agent_id": str(agent_id), } - task = ax.Task.from_serialized(task_def, data) + task = ax.Task.model_validate(data) assert task.task_name == "test_task" - assert isinstance(task.context, SampleContext) - assert task.context.message == "hello" - assert task.context.value == 42 + assert task.context == {"message": "hello", "value": 42} assert task.agent_id == agent_id -def test_task_round_trip(pool) -> None: - """Test that tasks can be serialized and deserialized.""" - # Register task for deserialization - @pool.task("round_trip_task") - async def handler(agent_id: uuid.UUID, context: NestedContext) -> TaskResult: - return TaskResult(status="success") - - task_def = pool._context.tasks["round_trip_task"] - - original_ctx = NestedContext(message="hello", nested={"key": "value"}) +def test_task_round_trip() -> None: + """Task survives serialize → JSON → deserialize.""" original = ax.Task( task_name="round_trip_task", - context=original_ctx, + context={"message": "hello", "nested": {"key": "value"}}, agent_id=uuid.uuid4(), ) - # Serialize → JSON → Deserialize serialized = original.model_dump_json() - data = json.loads(serialized) - deserialized = ax.Task.from_serialized(task_def, data) + deserialized = ax.Task.model_validate_json(serialized) assert deserialized.task_name == original.task_name - # Cast to access typed attributes (Task.context is typed as BaseModel) - assert isinstance(deserialized.context, NestedContext) - assert isinstance(original.context, NestedContext) - assert deserialized.context.message == original.context.message - assert deserialized.context.nested == original.context.nested + assert deserialized.context == original.context assert deserialized.agent_id == original.agent_id -def test_task_create_with_basemodel(monkeypatch) -> None: - """Test Task.create() with a BaseModel context.""" - # Mock activity.create to avoid database dependency - def mock_create(*args, **kwargs): +async def test_task_create_with_basemodel(monkeypatch) -> None: + """Task.create() serializes context to dict.""" + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) ctx = SampleContext(message="hello", value=42) - task = ax.Task.create("test_task", ctx) + task = await ax.Task.create("test_task", ctx) assert task.task_name == "test_task" - # Context is the typed object - assert isinstance(task.context, SampleContext) - assert task.context.message == "hello" - assert task.context.value == 42 + assert task.context == {"message": "hello", "value": 42} -def test_task_create_preserves_nested(monkeypatch) -> None: - """Test Task.create() preserves nested Pydantic models.""" - # Mock activity.create to avoid database dependency - def mock_create(*args, **kwargs): +async def test_task_create_preserves_nested(monkeypatch) -> None: + """Task.create() preserves nested structures in the dict.""" + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) ctx = NestedContext(message="hello", nested={"key": "value"}) - task = ax.Task.create("test_task", ctx) + task = await ax.Task.create("test_task", ctx) - assert isinstance(task.context, NestedContext) - assert task.context.message == "hello" - assert task.context.nested == {"key": "value"} + assert task.context == {"message": "hello", "nested": {"key": "value"}} -def test_task_from_serialized(pool) -> None: - """Test Task.from_serialized creates a task with typed context.""" - from agentexec.core.task import TaskDefinition - +def test_definition_hydrates_context(pool) -> None: + """TaskDefinition.hydrate_context validates dict into typed model.""" @pool.task("test_task") async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: - return TaskResult(status=f"Result: {context.message}") - - task_def = pool._context.tasks["test_task"] - agent_id = uuid.uuid4() - - data = { - "task_name": "test_task", - "context": {"message": "hello", "value": 42}, - "agent_id": str(agent_id), - } - - task = ax.Task.from_serialized(task_def, data) + return TaskResult(status="success") - assert task.task_name == "test_task" - assert isinstance(task.context, SampleContext) - assert task.context.message == "hello" - assert task.context.value == 42 - assert task.agent_id == agent_id - assert task._definition is task_def + definition = pool._context.tasks["test_task"] + typed = definition.hydrate_context({"message": "hello", "value": 42}) + assert isinstance(typed, SampleContext) + assert typed.message == "hello" + assert typed.value == 42 -async def test_task_execute_async_handler(pool, monkeypatch) -> None: - """Test Task.execute with an async handler.""" - from unittest.mock import AsyncMock - # Track activity updates +async def test_definition_execute_async(pool, monkeypatch) -> None: + """TaskDefinition.execute() runs async handler and tracks activity.""" activity_updates = [] - def mock_update(**kwargs): + async def mock_update(**kwargs): activity_updates.append(kwargs) - # Mock state.aset_result - aset_result_calls = [] - - async def mock_aset_result(agent_id, data, ttl_seconds=None): - aset_result_calls.append((agent_id, data, ttl_seconds)) + async def mock_state_set(key, value, ttl_seconds=None): + pass monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) - monkeypatch.setattr("agentexec.core.task.state.aset_result", mock_aset_result) + monkeypatch.setattr("agentexec.core.task.backend.state.set", mock_state_set) execution_result = TaskResult(status="success") @@ -198,64 +139,46 @@ async def mock_aset_result(agent_id, data, ttl_seconds=None): async def async_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: return execution_result - task_def = pool._context.tasks["async_task"] - agent_id = uuid.uuid4() - - task = ax.Task.from_serialized( - task_def, - { - "task_name": "async_task", - "context": {"message": "test"}, - "agent_id": str(agent_id), - }, + definition = pool._context.tasks["async_task"] + task = ax.Task( + task_name="async_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), ) - result = await task.execute() + result = await definition.execute(task) assert result == execution_result - # Verify activity was updated (started and completed) assert len(activity_updates) == 2 - # First update marks task as started assert activity_updates[0]["percentage"] == 0 - # Second update marks task as completed assert activity_updates[1]["percentage"] == 100 - # Verify result was stored - assert len(aset_result_calls) == 1 - assert aset_result_calls[0][0] == agent_id # Can be UUID or str - assert aset_result_calls[0][1] == execution_result - -async def test_task_execute_sync_handler(pool, monkeypatch) -> None: - """Test Task.execute with a sync handler.""" +async def test_definition_execute_sync(pool, monkeypatch) -> None: + """TaskDefinition.execute() runs sync handler.""" activity_updates = [] - def mock_update(**kwargs): + async def mock_update(**kwargs): activity_updates.append(kwargs) - async def mock_aset_result(agent_id, data, ttl_seconds=None): + async def mock_state_set(key, value, ttl_seconds=None): pass monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) - monkeypatch.setattr("agentexec.core.task.state.aset_result", mock_aset_result) + monkeypatch.setattr("agentexec.core.task.backend.state.set", mock_state_set) @pool.task("sync_task") def sync_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: return TaskResult(status=f"Sync result: {context.message}") - task_def = pool._context.tasks["sync_task"] - agent_id = uuid.uuid4() - - task = ax.Task.from_serialized( - task_def, - { - "task_name": "sync_task", - "context": {"message": "test"}, - "agent_id": str(agent_id), - }, + definition = pool._context.tasks["sync_task"] + task = ax.Task( + task_name="sync_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), ) - result = await task.execute() + result = await definition.execute(task) assert result is not None assert isinstance(result, TaskResult) @@ -263,54 +186,193 @@ def sync_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: assert len(activity_updates) == 2 -async def test_task_execute_without_definition_raises() -> None: - """Test Task.execute raises RuntimeError if not bound to definition.""" - task = ax.Task( - task_name="test_task", - context=SampleContext(message="test"), - agent_id=uuid.uuid4(), - ) - - with pytest.raises(RuntimeError, match="must be bound to a definition"): - await task.execute() - - -async def test_task_execute_error_marks_activity_errored(pool, monkeypatch) -> None: - """Test Task.execute marks activity as errored on exception.""" - from agentexec.activity.models import Status +async def test_definition_execute_error(pool, monkeypatch) -> None: + """TaskDefinition.execute() marks activity as errored on exception.""" + from agentexec.activity.status import Status activity_updates = [] - def mock_update(**kwargs): + async def mock_update(**kwargs): activity_updates.append(kwargs) - async def mock_aset_result(agent_id, data, ttl_seconds=None): + async def mock_state_set(key, value, ttl_seconds=None): pass monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) - monkeypatch.setattr("agentexec.core.task.state.aset_result", mock_aset_result) + monkeypatch.setattr("agentexec.core.task.backend.state.set", mock_state_set) @pool.task("failing_task") async def failing_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: raise ValueError("Task failed!") - task_def = pool._context.tasks["failing_task"] - agent_id = uuid.uuid4() - - task = ax.Task.from_serialized( - task_def, - { - "task_name": "failing_task", - "context": {"message": "test"}, - "agent_id": str(agent_id), - }, + definition = pool._context.tasks["failing_task"] + task = ax.Task( + task_name="failing_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), ) - # execute() catches the exception and marks activity as errored, returns None - result = await task.execute() + with pytest.raises(ValueError, match="Task failed!"): + await definition.execute(task) - assert result is None # Handler exception results in None return - # First update marks started, second marks errored assert len(activity_updates) == 2 assert activity_updates[1]["status"] == Status.ERROR assert "Task failed!" in activity_updates[1]["message"] + + +async def test_definition_execute_none_result_not_stored(pool, monkeypatch) -> None: + """Handler returning None does not write to result storage.""" + state_set_calls = [] + + async def mock_update(**kwargs): + pass + + async def mock_state_set(key, value, ttl_seconds=None): + state_set_calls.append(key) + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + monkeypatch.setattr("agentexec.core.task.backend.state.set", mock_state_set) + + @pool.task("void_task") + async def void_handler(agent_id: uuid.UUID, context: SampleContext) -> None: + pass + + definition = pool._context.tasks["void_task"] + task = ax.Task( + task_name="void_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), + ) + + result = await definition.execute(task) + assert result is None + assert len(state_set_calls) == 0 + + +async def test_definition_execute_stores_result_with_ttl(pool, monkeypatch) -> None: + """Handler result is stored in state with the configured TTL.""" + from agentexec.config import CONF + + state_set_calls = [] + + async def mock_update(**kwargs): + pass + + async def mock_state_set(key, value, ttl_seconds=None): + state_set_calls.append({"key": key, "ttl_seconds": ttl_seconds}) + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + monkeypatch.setattr("agentexec.core.task.backend.state.set", mock_state_set) + + @pool.task("result_task") + async def result_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: + return TaskResult(status="done") + + definition = pool._context.tasks["result_task"] + task = ax.Task( + task_name="result_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), + ) + + await definition.execute(task) + assert len(state_set_calls) == 1 + assert state_set_calls[0]["ttl_seconds"] == CONF.result_ttl + assert str(task.agent_id) in state_set_calls[0]["key"] + + +async def test_definition_execute_hydrates_context(pool, monkeypatch) -> None: + """execute() passes a typed context model to the handler, not a raw dict.""" + received_context = [] + + async def mock_update(**kwargs): + pass + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + + @pool.task("typed_task") + async def typed_handler(agent_id: uuid.UUID, context: SampleContext) -> None: + received_context.append(context) + + definition = pool._context.tasks["typed_task"] + task = ax.Task( + task_name="typed_task", + context={"message": "typed", "value": 7}, + agent_id=uuid.uuid4(), + ) + + await definition.execute(task) + assert len(received_context) == 1 + assert isinstance(received_context[0], SampleContext) + assert received_context[0].message == "typed" + assert received_context[0].value == 7 + + +async def test_definition_execute_passes_agent_id(pool, monkeypatch) -> None: + """execute() passes the task's agent_id to the handler.""" + received_ids = [] + + async def mock_update(**kwargs): + pass + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + + @pool.task("id_task") + async def id_handler(agent_id: uuid.UUID, context: SampleContext) -> None: + received_ids.append(agent_id) + + definition = pool._context.tasks["id_task"] + expected_id = uuid.uuid4() + task = ax.Task( + task_name="id_task", + context={"message": "test"}, + agent_id=expected_id, + ) + + await definition.execute(task) + assert received_ids == [expected_id] + + +async def test_definition_execute_error_reraises(pool, monkeypatch) -> None: + """execute() re-raises the original exception after marking activity as errored.""" + async def mock_update(**kwargs): + pass + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + + @pool.task("reraise_task") + async def bad_handler(agent_id: uuid.UUID, context: SampleContext): + raise RuntimeError("original error") + + definition = pool._context.tasks["reraise_task"] + task = ax.Task( + task_name="reraise_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), + ) + + with pytest.raises(RuntimeError, match="original error"): + await definition.execute(task) + + +async def test_definition_execute_bad_context_raises(pool, monkeypatch) -> None: + """execute() raises ValidationError when context doesn't match the registered type.""" + async def mock_update(**kwargs): + pass + + monkeypatch.setattr("agentexec.core.task.activity.update", mock_update) + + @pool.task("strict_task") + async def strict_handler(agent_id: uuid.UUID, context: SampleContext): + pass + + definition = pool._context.tasks["strict_task"] + task = ax.Task( + task_name="strict_task", + context={"wrong_field": "oops"}, # missing required 'message' + agent_id=uuid.uuid4(), + ) + + from pydantic import ValidationError + with pytest.raises(ValidationError): + await definition.execute(task) diff --git a/tests/test_task_locking.py b/tests/test_task_locking.py index ab1f853..d71e9a5 100644 --- a/tests/test_task_locking.py +++ b/tests/test_task_locking.py @@ -1,16 +1,11 @@ -"""Tests for task-level distributed locking.""" - -import json import uuid import pytest from fakeredis import aioredis as fake_aioredis from pydantic import BaseModel -from unittest.mock import AsyncMock, patch import agentexec as ax -from agentexec import state -from agentexec.core.queue import requeue +from agentexec.state import backend from agentexec.core.task import TaskDefinition @@ -36,20 +31,10 @@ def pool(): @pytest.fixture def fake_redis(monkeypatch): - """Setup fake redis for state backend with shared state.""" - import fakeredis - - server = fakeredis.FakeServer() - fake_redis_sync = fakeredis.FakeRedis(server=server, decode_responses=False) - fake_redis_async = fake_aioredis.FakeRedis(server=server, decode_responses=False) - - monkeypatch.setattr("agentexec.state.redis_backend._get_sync_client", lambda: fake_redis_sync) - monkeypatch.setattr("agentexec.state.redis_backend._get_async_client", lambda: fake_redis_async) - - yield fake_redis_async - - -# --- TaskDefinition lock_key --- + """Setup fake redis for state backend.""" + fake = fake_aioredis.FakeRedis(decode_responses=False) + monkeypatch.setattr(backend, "_client", fake) + yield fake def test_task_definition_lock_key_default(): @@ -72,9 +57,6 @@ async def handler(agent_id: uuid.UUID, context: UserContext) -> TaskResult: assert defn.lock_key == "user:{user_id}" -# --- Pool registration with lock_key --- - - def test_pool_task_decorator_with_lock_key(pool): """@pool.task() passes lock_key to TaskDefinition.""" @@ -109,130 +91,69 @@ async def handler(agent_id: uuid.UUID, context: UserContext) -> TaskResult: assert defn.lock_key == "user:{user_id}" -# --- Task.get_lock_key() --- - - def test_get_lock_key_evaluates_template(pool): - """get_lock_key() evaluates template against context fields.""" + """definition.get_lock_key() evaluates template against context.""" @pool.task("locked_task", lock_key="user:{user_id}") async def handler(agent_id: uuid.UUID, context: UserContext) -> TaskResult: return TaskResult(status="ok") - defn = pool._context.tasks["locked_task"] - task = ax.Task.from_serialized( - defn, - { - "task_name": "locked_task", - "context": {"user_id": "42", "message": "hello"}, - "agent_id": str(uuid.uuid4()), - }, - ) - - assert task.get_lock_key() == "user:42" + definition = pool._context.tasks["locked_task"] + assert definition.get_lock_key({"user_id": "42", "message": "hello"}) == "user:42" def test_get_lock_key_returns_none_when_no_lock(pool): - """get_lock_key() returns None when no lock_key configured.""" + """definition.get_lock_key() returns None when no lock_key configured.""" @pool.task("unlocked_task") async def handler(agent_id: uuid.UUID, context: UserContext) -> TaskResult: return TaskResult(status="ok") - defn = pool._context.tasks["unlocked_task"] - task = ax.Task.from_serialized( - defn, - { - "task_name": "unlocked_task", - "context": {"user_id": "42"}, - "agent_id": str(uuid.uuid4()), - }, - ) - - assert task.get_lock_key() is None - - -def test_get_lock_key_raises_without_definition(): - """get_lock_key() raises RuntimeError if task not bound to definition.""" - task = ax.Task( - task_name="test", - context=UserContext(user_id="42"), - agent_id=uuid.uuid4(), - ) - - with pytest.raises(RuntimeError, match="must be bound to a definition"): - task.get_lock_key() + definition = pool._context.tasks["unlocked_task"] + assert definition.get_lock_key({"user_id": "42"}) is None def test_get_lock_key_raises_on_missing_field(pool): - """get_lock_key() raises KeyError if template references missing field.""" + """definition.get_lock_key() raises KeyError if template references missing field.""" @pool.task("bad_template", lock_key="org:{organization_id}") async def handler(agent_id: uuid.UUID, context: UserContext) -> TaskResult: return TaskResult(status="ok") - defn = pool._context.tasks["bad_template"] - task = ax.Task.from_serialized( - defn, - { - "task_name": "bad_template", - "context": {"user_id": "42"}, - "agent_id": str(uuid.uuid4()), - }, - ) - + definition = pool._context.tasks["bad_template"] with pytest.raises(KeyError): - task.get_lock_key() - - -# --- Redis lock acquire/release --- + definition.get_lock_key({"user_id": "42"}) async def test_acquire_lock_success(fake_redis): - """acquire_lock returns True when lock is free.""" - result = await state.acquire_lock("user:42", "agent-1") + """Queue backend acquire_lock returns True when lock is free.""" + queue_key = backend.queue._queue_key("user:42").encode() + result = await backend.queue._acquire_lock(queue_key) assert result is True async def test_acquire_lock_already_held(fake_redis): - """acquire_lock returns False when lock is already held.""" - await state.acquire_lock("user:42", "agent-1") - result = await state.acquire_lock("user:42", "agent-2") + """Queue backend acquire_lock returns False when already held.""" + queue_key = backend.queue._queue_key("user:42").encode() + await backend.queue._acquire_lock(queue_key) + result = await backend.queue._acquire_lock(queue_key) assert result is False async def test_release_lock(fake_redis): """release_lock frees the lock so it can be re-acquired.""" - await state.acquire_lock("user:42", "agent-1") - await state.release_lock("user:42") + queue_key = backend.queue._queue_key("user:42").encode() + await backend.queue._acquire_lock(queue_key) + await backend.queue.complete("user:42") - result = await state.acquire_lock("user:42", "agent-2") + result = await backend.queue._acquire_lock(queue_key) assert result is True -async def test_release_lock_nonexistent(fake_redis): - """release_lock on a non-existent key returns 0.""" - result = await state.release_lock("nonexistent") - assert result == 0 - - -async def test_lock_key_uses_prefix(fake_redis): - """Lock keys are prefixed with agentexec:lock:.""" - await state.acquire_lock("user:42", "agent-1") - - # Check the raw Redis key - value = await fake_redis.get("agentexec:lock:user:42") - assert value is not None - assert value.decode() == "agent-1" - - -# --- Requeue --- - - async def test_requeue_pushes_to_back(fake_redis, monkeypatch): """requeue() pushes task to the back of the queue (lpush).""" - def mock_create(*args, **kwargs): + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @@ -240,21 +161,18 @@ def mock_create(*args, **kwargs): # Enqueue a normal task first task1 = await ax.enqueue("task_1", UserContext(user_id="1", message="first")) - # Create and requeue a second task + # Push a second task directly (simulating a requeue) task2 = ax.Task( task_name="task_2", - context=UserContext(user_id="2", message="requeued"), + context={"user_id": "2", "message": "requeued"}, agent_id=uuid.uuid4(), ) - requeue(task2) - - # Dequeue should return task_1 first (from front/right), then task_2 (from back/left) - from agentexec.core.queue import dequeue + await backend.queue.push(task2.model_dump_json()) - result1 = await dequeue(timeout=1) + result1 = await backend.queue.pop(timeout=1) assert result1 is not None assert result1["task_name"] == "task_1" - result2 = await dequeue(timeout=1) + result2 = await backend.queue.pop(timeout=1) assert result2 is not None assert result2["task_name"] == "task_2" diff --git a/tests/test_worker_event.py b/tests/test_worker_event.py index 950bc06..4e83eb4 100644 --- a/tests/test_worker_event.py +++ b/tests/test_worker_event.py @@ -1,133 +1,71 @@ -"""Test state-backed event for cross-process coordination.""" - import pytest from fakeredis import aioredis as fake_aioredis -import fakeredis +from agentexec.state import backend from agentexec.worker.event import StateEvent @pytest.fixture -def fake_redis_sync(monkeypatch): - """Setup fake sync redis for state backend.""" - fake_redis = fakeredis.FakeRedis(decode_responses=False) - - def get_fake_sync_client(): - return fake_redis - - monkeypatch.setattr("agentexec.state.redis_backend._get_sync_client", get_fake_sync_client) - - yield fake_redis - - -@pytest.fixture -def fake_redis_async(monkeypatch): - """Setup fake async redis for state backend.""" - fake_redis = fake_aioredis.FakeRedis(decode_responses=False) - - def get_fake_async_client(): - return fake_redis - - monkeypatch.setattr("agentexec.state.redis_backend._get_async_client", get_fake_async_client) - - yield fake_redis +def fake_redis(monkeypatch): + """Inject fake redis into the backend.""" + fake = fake_aioredis.FakeRedis(decode_responses=False) + monkeypatch.setattr(backend, "_client", fake) + yield fake def test_state_event_initialization(): - """Test StateEvent can be initialized with name and id.""" event = StateEvent("test", "event123") - assert event.name == "test" assert event.id == "event123" -def test_redis_event_set(fake_redis_sync): - """Test StateEvent.set() sets the key in Redis.""" +async def test_redis_event_set(fake_redis): event = StateEvent("shutdown", "pool1") - - event.set() - - # Verify the key was set (with event prefix and formatted name:id) - value = fake_redis_sync.get("agentexec:event:shutdown:pool1") + await event.set() + value = await fake_redis.get("agentexec:event:shutdown:pool1") assert value == b"1" -def test_redis_event_clear(fake_redis_sync): - """Test StateEvent.clear() removes the key from Redis.""" +async def test_redis_event_clear(fake_redis): event = StateEvent("shutdown", "pool2") - - # Set then clear - fake_redis_sync.set("agentexec:event:shutdown:pool2", "1") - event.clear() - - # Verify the key was removed - value = fake_redis_sync.get("agentexec:event:shutdown:pool2") + await fake_redis.set("agentexec:event:shutdown:pool2", "1") + await event.clear() + value = await fake_redis.get("agentexec:event:shutdown:pool2") assert value is None -def test_redis_event_clear_nonexistent(fake_redis_sync): - """Test StateEvent.clear() handles non-existent keys gracefully.""" +async def test_redis_event_clear_nonexistent(fake_redis): event = StateEvent("nonexistent", "id123") - - # Should not raise an error - event.clear() + await event.clear() -async def test_redis_event_is_set_true(fake_redis_async): - """Test StateEvent.is_set() returns True when key exists.""" +async def test_redis_event_is_set_true(fake_redis): event = StateEvent("shutdown", "pool3") + await fake_redis.set("agentexec:event:shutdown:pool3", "1") + assert await event.is_set() is True - # Set the key - await fake_redis_async.set("agentexec:event:shutdown:pool3", "1") - # Check is_set - result = await event.is_set() - assert result is True - - -async def test_redis_event_is_set_false(fake_redis_async): - """Test StateEvent.is_set() returns False when key doesn't exist.""" +async def test_redis_event_is_set_false(fake_redis): event = StateEvent("shutdown", "pool4") - - # Don't set the key - result = await event.is_set() - assert result is False + assert await event.is_set() is False -async def test_redis_event_is_set_after_clear(fake_redis_sync, fake_redis_async): - """Test StateEvent.is_set() returns False after clear().""" +async def test_redis_event_is_set_after_clear(fake_redis): event = StateEvent("shutdown", "pool5") - - # Set then clear - event.set() - event.clear() - - # Check is_set - result = await event.is_set() - assert result is False + await event.set() + await event.clear() + assert await event.is_set() is False def test_redis_event_picklable(): - """Test StateEvent is picklable (for multiprocessing).""" import pickle - event = StateEvent("shutdown", "pickle123") - - # Pickle and unpickle - pickled = pickle.dumps(event) - unpickled = pickle.loads(pickled) - + unpickled = pickle.loads(pickle.dumps(event)) assert unpickled.name == "shutdown" assert unpickled.id == "pickle123" def test_redis_event_multiple_events(): - """Test multiple StateEvent instances with different names.""" event1 = StateEvent("event", "id1") event2 = StateEvent("event", "id2") - assert event1.id != event2.id - assert event1.name == "event" - assert event2.name == "event" - assert event1.id == "id1" - assert event2.id == "id2" diff --git a/tests/test_worker_logging.py b/tests/test_worker_logging.py index dc9662e..be6489b 100644 --- a/tests/test_worker_logging.py +++ b/tests/test_worker_logging.py @@ -1,17 +1,15 @@ -"""Test worker logging functionality.""" - import logging import time import pytest -import fakeredis +from fakeredis import aioredis as fake_aioredis from agentexec.worker.logging import ( DEFAULT_FORMAT, LOG_CHANNEL, LOGGER_NAME, LogMessage, - StateLogHandler, + QueueLogHandler, get_worker_logger, ) @@ -135,44 +133,25 @@ def test_log_message_with_none_values(self): assert log_message.thread is None -class TestStateLogHandler: - """Tests for StateLogHandler.""" +class TestQueueLogHandler: + """Tests for QueueLogHandler.""" - @pytest.fixture - def fake_redis_backend(self, monkeypatch): - """Setup fake redis backend for state.""" - fake_redis = fakeredis.FakeRedis(decode_responses=False) + def test_handler_initialization(self): + """Test QueueLogHandler initializes with a queue.""" + import multiprocessing as mp + tx = mp.Queue() + handler = QueueLogHandler(tx) + assert handler.tx is tx - def get_fake_sync_client(): - return fake_redis + def test_handler_emit(self): + """Test QueueLogHandler.emit() puts LogEntry on the queue.""" + import multiprocessing as mp + import time + from agentexec.worker.pool import LogEntry - monkeypatch.setattr( - "agentexec.state.redis_backend._get_sync_client", get_fake_sync_client - ) + tx = mp.Queue() + handler = QueueLogHandler(tx) - return fake_redis - - def test_handler_initialization(self): - """Test StateLogHandler initializes with default channel.""" - handler = StateLogHandler() - assert handler.channel == LOG_CHANNEL - - def test_handler_custom_channel(self): - """Test StateLogHandler with custom channel.""" - handler = StateLogHandler(channel="custom:logs") - assert handler.channel == "custom:logs" - - def test_handler_emit(self, fake_redis_backend): - """Test StateLogHandler.emit() publishes to state backend.""" - handler = StateLogHandler() - - # Subscribe to the channel to capture the message - pubsub = fake_redis_backend.pubsub() - pubsub.subscribe(LOG_CHANNEL) - # Get the subscribe message - pubsub.get_message() - - # Create and emit a log record record = logging.LogRecord( name="emit.test", level=logging.INFO, @@ -184,18 +163,12 @@ def test_handler_emit(self, fake_redis_backend): ) handler.emit(record) + time.sleep(0.1) # mp.Queue uses a background thread to flush - # Get the published message - message = pubsub.get_message() - - assert message is not None - assert message["type"] == "message" - assert message["channel"] == LOG_CHANNEL.encode() - - # Verify the message content - log_message = LogMessage.model_validate_json(message["data"]) - assert log_message.msg == "Emitted message" - assert log_message.levelno == logging.INFO + message = tx.get_nowait() + assert isinstance(message, LogEntry) + assert message.record.msg == "Emitted message" + assert message.record.levelno == logging.INFO class TestGetWorkerLogger: @@ -204,18 +177,10 @@ class TestGetWorkerLogger: @pytest.fixture(autouse=True) def reset_logging_state(self, monkeypatch): """Reset the worker logging configured state before each test.""" - # Reset the global state monkeypatch.setattr("agentexec.worker.logging._worker_logging_configured", False) - # Setup fake redis backend - fake_redis = fakeredis.FakeRedis(decode_responses=False) - monkeypatch.setattr( - "agentexec.state.redis_backend._get_sync_client", lambda: fake_redis - ) - yield - # Cleanup handlers added during tests root = logging.getLogger(LOGGER_NAME) root.handlers.clear() @@ -239,18 +204,21 @@ def test_get_worker_logger_existing_namespace(self): assert logger.name == f"{LOGGER_NAME}.submodule" def test_get_worker_logger_configures_handler(self): - """Test get_worker_logger adds StateLogHandler on first call.""" - logger = get_worker_logger("first.call") + """Test get_worker_logger adds QueueLogHandler on first call.""" + import multiprocessing as mp + tx = mp.Queue() + get_worker_logger("first.call", tx=tx) root = logging.getLogger(LOGGER_NAME) handler_types = [type(h).__name__ for h in root.handlers] - assert "StateLogHandler" in handler_types + assert "QueueLogHandler" in handler_types def test_get_worker_logger_idempotent(self): """Test get_worker_logger only configures once.""" - # First call - get_worker_logger("first") + import multiprocessing as mp + tx = mp.Queue() + get_worker_logger("first", tx=tx) root = logging.getLogger(LOGGER_NAME) initial_handler_count = len(root.handlers) diff --git a/tests/test_worker_pool.py b/tests/test_worker_pool.py index 20f5bc1..b0b2ced 100644 --- a/tests/test_worker_pool.py +++ b/tests/test_worker_pool.py @@ -1,12 +1,13 @@ -"""Test Pool implementation.""" - import json +import multiprocessing as mp import uuid +from unittest.mock import AsyncMock import pytest from pydantic import BaseModel import agentexec as ax +from agentexec.state import backend class SampleContext(BaseModel): @@ -24,22 +25,19 @@ class TaskResult(BaseModel): @pytest.fixture def mock_state_backend(monkeypatch): - """Mock the state backend for queue operations.""" + """Mock the queue ops for push operations.""" queue_data = [] - def mock_lpush(key, value): - queue_data.insert(0, value) - return len(queue_data) - - def mock_rpush(key, value): - queue_data.append(value) - return len(queue_data) + async def mock_queue_push(value, *, high_priority=False, partition_key=None): + if high_priority: + queue_data.append(value) + else: + queue_data.insert(0, value) def pop_right(): return queue_data.pop() if queue_data else None - monkeypatch.setattr("agentexec.state.backend.lpush", mock_lpush) - monkeypatch.setattr("agentexec.state.backend.rpush", mock_rpush) + monkeypatch.setattr("agentexec.state.backend.queue.push", mock_queue_push) return {"queue": queue_data, "pop": pop_right} @@ -55,8 +53,7 @@ def pool(): async def test_enqueue_task(mock_state_backend, pool, monkeypatch) -> None: """Test that tasks can be enqueued.""" - # Mock activity.create to avoid database dependency - def mock_create(*args, **kwargs): + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @@ -74,8 +71,7 @@ async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: assert task is not None assert isinstance(task.agent_id, uuid.UUID) assert task.task_name == "test_task" - assert isinstance(task.context, SampleContext) - assert task.context.message == "Hello World" + assert task.context["message"] == "Hello World" # Verify task was pushed to queue task_json = mock_state_backend["pop"]() @@ -89,7 +85,7 @@ async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: async def test_enqueue_high_priority_task(mock_state_backend, pool, monkeypatch) -> None: """Test that high priority tasks are enqueued to the front.""" - def mock_create(*args, **kwargs): + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @@ -111,7 +107,7 @@ async def high_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResul ctx2 = SampleContext(message="high", value=2) task2 = await ax.enqueue("high_task", ctx2, priority=ax.Priority.HIGH) - # High priority task should be at the end (RPUSH) so it's processed first (BRPOP) + # High priority task should be at the end (popped first) task_json = mock_state_backend["pop"]() task_data = json.loads(task_json) assert task_data["agent_id"] == str(task2.agent_id) @@ -119,7 +115,7 @@ async def high_handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResul async def test_add_task_registers_handler(mock_state_backend, pool, monkeypatch) -> None: """Test that pool.add_task() registers a task handler.""" - def mock_create(*args, **kwargs): + async def mock_create(*args, **kwargs): return uuid.uuid4() monkeypatch.setattr("agentexec.core.task.activity.create", mock_create) @@ -139,8 +135,7 @@ async def handler(*, agent_id: uuid.UUID, context: SampleContext) -> TaskResult: assert task is not None assert task.task_name == "added_task" - assert isinstance(task.context, SampleContext) - assert task.context.message == "Added via add_task" + assert task.context["message"] == "Added via add_task" def test_add_task_duplicate_raises(pool) -> None: @@ -182,19 +177,10 @@ def test_pool_with_database_url() -> None: """Test that Pool can be created with database_url.""" pool = ax.Pool(database_url="sqlite:///:memory:") - assert pool._context.database_url == "sqlite:///:memory:" + assert pool._processes == [] assert pool._processes == [] -def test_pool_with_custom_queue_name() -> None: - """Test that Pool can use a custom queue name.""" - pool = ax.Pool( - database_url="sqlite:///:memory:", - queue_name="custom_queue", - ) - - assert pool._context.queue_name == "custom_queue" - async def test_worker_dequeue_task(pool, monkeypatch) -> None: """Test Worker._dequeue_task method.""" @@ -206,15 +192,12 @@ async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: return TaskResult() context = WorkerContext( - database_url="sqlite:///:memory:", shutdown_event=StateEvent("shutdown", "test-worker"), tasks=pool._context.tasks, - queue_name="test_queue", + tx=mp.Queue(), ) - worker = Worker(worker_id=0, context=context) - - # Mock dequeue to return task data + # Mock queue_pop to return task data agent_id = uuid.uuid4() task_data = { "task_name": "test_task", @@ -222,53 +205,40 @@ async def handler(agent_id: uuid.UUID, context: SampleContext) -> TaskResult: "agent_id": str(agent_id), } - async def mock_dequeue(**kwargs): + async def mock_queue_pop(*args, **kwargs): return task_data - monkeypatch.setattr("agentexec.worker.pool.dequeue", mock_dequeue) + monkeypatch.setattr("agentexec.state.backend.queue.pop", mock_queue_pop) - task = await worker._dequeue_task() + data = await backend.queue.pop(timeout=1) + assert data is not None - assert task is not None + task = ax.Task.model_validate(data) assert task.task_name == "test_task" - assert isinstance(task.context, SampleContext) - assert task.context.message == "test" + assert task.context == {"message": "test", "value": 42} assert task.agent_id == agent_id -async def test_worker_dequeue_task_returns_none_on_empty_queue(pool, monkeypatch) -> None: - """Test Worker._dequeue_task returns None when queue is empty.""" - from agentexec.worker.pool import Worker, WorkerContext - from agentexec.worker.event import StateEvent - - context = WorkerContext( - database_url="sqlite:///:memory:", - shutdown_event=StateEvent("shutdown", "test-worker"), - tasks=pool._context.tasks, - queue_name="test_queue", - ) - - worker = Worker(worker_id=0, context=context) +async def test_dequeue_returns_none_on_empty_queue(pool, monkeypatch) -> None: + """Test pop returns None when queue is empty.""" - async def mock_dequeue(**kwargs): + async def mock_queue_pop(*args, **kwargs): return None - monkeypatch.setattr("agentexec.worker.pool.dequeue", mock_dequeue) - - task = await worker._dequeue_task() + monkeypatch.setattr("agentexec.state.backend.queue.pop", mock_queue_pop) - assert task is None + data = await backend.queue.pop(timeout=1) + assert data is None -def test_worker_pool_shutdown_with_no_processes(pool, monkeypatch) -> None: +async def test_worker_pool_shutdown_with_no_processes(pool) -> None: """Test shutdown when no processes have been started.""" - # Mock the shutdown event to avoid Redis dependency - from unittest.mock import MagicMock + from unittest.mock import AsyncMock - pool._context.shutdown_event = MagicMock() + pool._context.shutdown_event = AsyncMock() # Should not raise even with empty process list - pool.shutdown(timeout=1) + await pool.shutdown(timeout=1) assert pool._processes == [] pool._context.shutdown_event.set.assert_called_once() @@ -282,3 +252,285 @@ def test_get_pool_id() -> None: id2 = _get_pool_id() assert id1 != id2 + + +class TestTaskFailed: + def test_from_exception(self): + """TaskFailed.from_exception captures the error string.""" + from agentexec.worker.pool import TaskFailed + + task = ax.Task( + task_name="test_task", + context={"message": "hello"}, + agent_id=uuid.uuid4(), + ) + exc = RuntimeError("something broke") + msg = TaskFailed.from_exception(task, exc) + + assert msg.task == task + assert msg.error == "something broke" + + def test_preserves_retry_count(self): + """TaskFailed preserves the task's current retry_count.""" + from agentexec.worker.pool import TaskFailed + + task = ax.Task( + task_name="test_task", + context={"message": "hello"}, + agent_id=uuid.uuid4(), + retry_count=2, + ) + msg = TaskFailed.from_exception(task, ValueError("bad")) + assert msg.task.retry_count == 2 + + +class TestWorkerFailurePath: + """Test that Worker._run sends TaskFailed on handler exception.""" + + async def test_exception_sends_task_failed(self, pool, monkeypatch): + """Handler exception → TaskFailed sent via IPC queue.""" + from agentexec.worker.pool import Worker, WorkerContext, TaskFailed + + @pool.task("failing_task") + async def handler(agent_id: uuid.UUID, context: SampleContext): + raise RuntimeError("handler exploded") + + tx = mp.Queue() + call_count = 0 + shutdown = AsyncMock() + + async def is_set(): + nonlocal call_count + return call_count > 1 + + shutdown.is_set = is_set + + context = WorkerContext( + shutdown_event=shutdown, + tasks=pool._context.tasks, + tx=tx, + ) + + async def mock_pop(*, timeout=1): + nonlocal call_count + call_count += 1 + if call_count == 1: + return { + "task_name": "failing_task", + "context": {"message": "boom"}, + "agent_id": str(uuid.uuid4()), + } + return None + + import agentexec.activity as activity_mod + monkeypatch.setattr(activity_mod, "update", AsyncMock()) + monkeypatch.setattr("agentexec.state.backend.queue.pop", mock_pop) + monkeypatch.setattr("agentexec.state.backend.queue.complete", AsyncMock()) + + worker = Worker(0, context) + + # Capture _send calls directly to avoid mp.Queue reliability issues + sent_messages = [] + original_send = worker._send + def capture_send(message): + sent_messages.append(message) + original_send(message) + monkeypatch.setattr(worker, "_send", capture_send) + + await worker._run() + + failed = [m for m in sent_messages if isinstance(m, TaskFailed)] + assert len(failed) == 1 + assert failed[0].error == "handler exploded" + assert failed[0].task.task_name == "failing_task" + + async def test_complete_called_after_failure(self, pool, monkeypatch): + """queue.complete is called even when the handler throws.""" + from agentexec.worker.pool import Worker, WorkerContext + + @pool.task("locked_fail") + async def handler(agent_id: uuid.UUID, context: SampleContext): + raise ValueError("oops") + + pool._context.tasks["locked_fail"].lock_key = "msg:{message}" + + tx = mp.Queue() + call_count = 0 + shutdown = AsyncMock() + + async def is_set(): + nonlocal call_count + return call_count > 1 + + shutdown.is_set = is_set + + context = WorkerContext( + shutdown_event=shutdown, + tasks=pool._context.tasks, + tx=tx, + ) + + async def mock_pop(*, timeout=1): + nonlocal call_count + call_count += 1 + if call_count == 1: + return { + "task_name": "locked_fail", + "context": {"message": "test"}, + "agent_id": str(uuid.uuid4()), + } + return None + + completed_keys = [] + + async def mock_complete(partition_key): + completed_keys.append(partition_key) + + import agentexec.activity as activity_mod + monkeypatch.setattr(activity_mod, "update", AsyncMock()) + monkeypatch.setattr("agentexec.state.backend.queue.pop", mock_pop) + monkeypatch.setattr("agentexec.state.backend.queue.complete", mock_complete) + + worker = Worker(0, context) + await worker._run() + + assert completed_keys == ["msg:test"] + + +class TestPoolRetryLogic: + """Test that Pool._process_worker_events handles TaskFailed correctly.""" + + async def test_requeues_with_incremented_retry(self, pool, monkeypatch): + """Failed task with retries remaining is requeued as high priority.""" + from agentexec.worker.pool import TaskFailed + + @pool.task("retry_task") + async def handler(agent_id: uuid.UUID, context: SampleContext): + pass + + task = ax.Task( + task_name="retry_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), + retry_count=0, + ) + + pushed = [] + + async def mock_push(value, *, high_priority=False, partition_key=None): + pushed.append({"value": value, "high_priority": high_priority, "partition_key": partition_key}) + + monkeypatch.setattr("agentexec.state.backend.queue.push", mock_push) + monkeypatch.setattr(ax.CONF, "max_task_retries", 3) + + # Put a TaskFailed message in the worker queue + pool._worker_queue.put_nowait(TaskFailed(task=task, error="boom")) + + # Simulate one iteration of _process_worker_events + # We need a fake process that reports alive once then dead + class FakeProcess: + def __init__(self): + self._calls = 0 + + def is_alive(self): + self._calls += 1 + return self._calls <= 2 # alive for first check, dead on second + + pool._processes = [FakeProcess()] + pool._log_handler = __import__("logging").StreamHandler() + + await pool._process_worker_events() + + assert len(pushed) == 1 + requeued = json.loads(pushed[0]["value"]) + assert requeued["retry_count"] == 1 + assert pushed[0]["high_priority"] is True + + async def test_gives_up_after_max_retries(self, pool, monkeypatch, capsys): + """Failed task at max retries is not requeued.""" + from agentexec.worker.pool import TaskFailed + + @pool.task("doomed_task") + async def handler(agent_id: uuid.UUID, context: SampleContext): + pass + + task = ax.Task( + task_name="doomed_task", + context={"message": "test"}, + agent_id=uuid.uuid4(), + retry_count=3, + ) + + pushed = [] + + async def mock_push(value, *, high_priority=False, partition_key=None): + pushed.append(value) + + monkeypatch.setattr("agentexec.state.backend.queue.push", mock_push) + monkeypatch.setattr(ax.CONF, "max_task_retries", 3) + + pool._worker_queue.put_nowait(TaskFailed(task=task, error="fatal")) + + class FakeProcess: + def __init__(self): + self._calls = 0 + + def is_alive(self): + self._calls += 1 + return self._calls <= 2 + + pool._processes = [FakeProcess()] + pool._log_handler = __import__("logging").StreamHandler() + + await pool._process_worker_events() + + # Should NOT have requeued + assert len(pushed) == 0 + + # Should have printed the give-up message + captured = capsys.readouterr() + assert "doomed_task" in captured.out + assert "4 attempts" in captured.out + assert "fatal" in captured.out + + async def test_retry_preserves_partition_key(self, pool, monkeypatch): + """Requeued task uses the correct partition key from its definition.""" + from agentexec.worker.pool import TaskFailed + + @pool.task("partitioned_task") + async def handler(agent_id: uuid.UUID, context: SampleContext): + pass + + pool._context.tasks["partitioned_task"].lock_key = "msg:{message}" + + task = ax.Task( + task_name="partitioned_task", + context={"message": "hello"}, + agent_id=uuid.uuid4(), + retry_count=0, + ) + + pushed = [] + + async def mock_push(value, *, high_priority=False, partition_key=None): + pushed.append({"partition_key": partition_key}) + + monkeypatch.setattr("agentexec.state.backend.queue.push", mock_push) + monkeypatch.setattr(ax.CONF, "max_task_retries", 3) + + pool._worker_queue.put_nowait(TaskFailed(task=task, error="transient")) + + class FakeProcess: + def __init__(self): + self._calls = 0 + + def is_alive(self): + self._calls += 1 + return self._calls <= 2 + + pool._processes = [FakeProcess()] + pool._log_handler = __import__("logging").StreamHandler() + + await pool._process_worker_events() + + assert pushed[0]["partition_key"] == "msg:hello" diff --git a/uv.lock b/uv.lock index cc95411..0e2ab2e 100644 --- a/uv.lock +++ b/uv.lock @@ -4,7 +4,7 @@ requires-python = ">=3.12" [[package]] name = "agentexec" -version = "0.1.6" +version = "0.1.7" source = { editable = "." } dependencies = [ { name = "croniter" }, @@ -15,6 +15,11 @@ dependencies = [ { name = "sqlalchemy" }, ] +[package.optional-dependencies] +kafka = [ + { name = "aiokafka" }, +] + [package.dev-dependencies] dev = [ { name = "fakeredis" }, @@ -28,6 +33,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiokafka", marker = "extra == 'kafka'", specifier = ">=0.11.0" }, { name = "croniter", specifier = ">=6.0.0" }, { name = "openai-agents", specifier = ">=0.1.0" }, { name = "pydantic", specifier = ">=2.12.0" }, @@ -35,6 +41,7 @@ requires-dist = [ { name = "redis", specifier = ">=7.0.1" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, ] +provides-extras = ["kafka"] [package.metadata.requires-dev] dev = [ @@ -47,6 +54,37 @@ dev = [ { name = "ty", specifier = ">=0.0.1a7" }, ] +[[package]] +name = "aiokafka" +version = "0.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout" }, + { name = "packaging" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/87/18/d3a4f8f9ad099fc59217b8cdf66eeecde3a9ef3bb31fe676e431a3b0010f/aiokafka-0.13.0.tar.gz", hash = "sha256:7d634af3c8d694a37a6c8535c54f01a740e74cccf7cc189ecc4a3d64e31ce122", size = 598580, upload-time = "2026-01-02T13:55:18.911Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/17/715ac23b4f8df3ff8d7c0a6f1c5fd3a179a8a675205be62d1d1bb27dffa2/aiokafka-0.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:231ecc0038c2736118f1c95149550dbbdf7b7a12069f70c005764fa1824c35d4", size = 346168, upload-time = "2026-01-02T13:54:49.128Z" }, + { url = "https://files.pythonhosted.org/packages/00/26/71c6f4cce2c710c6ffa18b9e294384157f46b0491d5b020de300802d167e/aiokafka-0.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2e2817593cab4c71c1d3b265b2446da91121a467ff7477c65f0f39a80047bc28", size = 349037, upload-time = "2026-01-02T13:54:50.48Z" }, + { url = "https://files.pythonhosted.org/packages/82/18/7b86418a4d3dc1303e89c0391942258ead31c02309e90eb631f3081eec1d/aiokafka-0.13.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b80e0aa1c811a9a12edb0b94445a0638d61a345932f785d47901d28b8aad86c8", size = 1140066, upload-time = "2026-01-02T13:54:52.33Z" }, + { url = "https://files.pythonhosted.org/packages/f9/51/45e46b4407d39b950c8493e19498aeeb5af4fc461fb54fa0247da16bfd75/aiokafka-0.13.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:79672c456bd1642769e74fc2db1c34f23b15500e978fd38411662e8ca07590ad", size = 1130088, upload-time = "2026-01-02T13:54:53.786Z" }, + { url = "https://files.pythonhosted.org/packages/49/7f/6a66f6fd6fb73e15bd34f574e38703ba36d3f9256c80e7aba007bd8a9256/aiokafka-0.13.0-cp312-cp312-win32.whl", hash = "sha256:00bb4e3d5a237b8618883eb1dd8c08d671db91d3e8e33ac98b04edf64225658c", size = 309581, upload-time = "2026-01-02T13:54:55.444Z" }, + { url = "https://files.pythonhosted.org/packages/d3/e0/a2d5a8912699dd0fee28e6fb780358c63c7a4727517fffc110cb7e43f874/aiokafka-0.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:0f0cccdf2fd16927fbe077279524950676fbffa7b102d6b117041b3461b5d927", size = 329327, upload-time = "2026-01-02T13:54:56.981Z" }, + { url = "https://files.pythonhosted.org/packages/e3/f6/a74c49759233e98b61182ba3d49d5ac9c8de0643651892acba2704fba1cc/aiokafka-0.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:39d71c40cff733221a6b2afff4beeac5dacbd119fb99eec5198af59115264a1a", size = 343733, upload-time = "2026-01-02T13:54:58.536Z" }, + { url = "https://files.pythonhosted.org/packages/cf/52/4f7e80eee2c69cd8b047c18145469bf0dc27542a5dca3f96ff81ade575b0/aiokafka-0.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:faa2f5f3d0d2283a0c1a149748cc7e3a3862ef327fa5762e2461088eedde230a", size = 346258, upload-time = "2026-01-02T13:55:00.947Z" }, + { url = "https://files.pythonhosted.org/packages/81/9b/d2766bb3b0bad53eb25a88e51a884be4b77a1706053ad717b893b4daea4b/aiokafka-0.13.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b890d535e55f5073f939585bef5301634df669e97832fda77aa743498f008662", size = 1114744, upload-time = "2026-01-02T13:55:02.475Z" }, + { url = "https://files.pythonhosted.org/packages/8f/00/12e0a39cd4809149a09b4a52b629abc9bf80e7b8bad9950040b1adae99fc/aiokafka-0.13.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e22eb8a1475b9c0f45b553b6e2dcaf4ec3c0014bf4e389e00a0a0ec85d0e3bdc", size = 1105676, upload-time = "2026-01-02T13:55:04.036Z" }, + { url = "https://files.pythonhosted.org/packages/38/4a/0bc91e90faf55533fe6468461c2dd31c22b0e1d274b9386f341cca3f7eb7/aiokafka-0.13.0-cp313-cp313-win32.whl", hash = "sha256:ae507c7b09e882484f709f2e7172b3a4f75afffcd896d00517feb35c619495bb", size = 308257, upload-time = "2026-01-02T13:55:05.873Z" }, + { url = "https://files.pythonhosted.org/packages/23/63/5433d1aa10c4fb4cf85bd73013263c36d7da4604b0c77ed4d1ad42fae70c/aiokafka-0.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:fec1a7e3458365a72809edaa2b990f65ca39b01a2a579f879ac4da6c9b2dbc5c", size = 326968, upload-time = "2026-01-02T13:55:07.351Z" }, + { url = "https://files.pythonhosted.org/packages/3c/cc/45b04c3a5fd3d2d5f444889ecceb80b2f78d6d66aa45e3042767e55579e2/aiokafka-0.13.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9a403785f7092c72906c37f7618f7b16a4219eba8ed0bdda90fba410a7dd50b5", size = 344503, upload-time = "2026-01-02T13:55:08.723Z" }, + { url = "https://files.pythonhosted.org/packages/76/df/0b76fe3b93558ae71b856940e384909c4c2c7a1c330423003191e4ba7782/aiokafka-0.13.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:256807326831b7eee253ea1017bd2b19ab1c2298ce6b20a87fde97c253c572bc", size = 347621, upload-time = "2026-01-02T13:55:10.147Z" }, + { url = "https://files.pythonhosted.org/packages/34/1a/d59932f98fd3c106e2a7c8d4d5ebd8df25403436dfc27b3031918a37385e/aiokafka-0.13.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:64d90f91291da265d7f25296ba68fc6275684eebd6d1cf05a1b2abe6c2ba3543", size = 1111410, upload-time = "2026-01-02T13:55:11.763Z" }, + { url = "https://files.pythonhosted.org/packages/7e/04/fbf3e34ab3bc21e6e760c3fcd089375052fccc04eb8745459a82a58a647b/aiokafka-0.13.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b5a33cc043c8d199bcf101359d86f2d31fd54f4b157ac12028bdc34e3e1cf74a", size = 1094799, upload-time = "2026-01-02T13:55:13.795Z" }, + { url = "https://files.pythonhosted.org/packages/85/10/509f709fd3b7c3e568a5b8044be0e80a1504f8da6ddc72c128b21e270913/aiokafka-0.13.0-cp314-cp314-win32.whl", hash = "sha256:538950384b539ba2333d35a853f09214c0409e818e5d5f366ef759eea50bae9c", size = 311553, upload-time = "2026-01-02T13:55:15.928Z" }, + { url = "https://files.pythonhosted.org/packages/2b/18/424d6a4eb6f4835a371c1e2cfafce800540b33d957c6638795d911f98973/aiokafka-0.13.0-cp314-cp314-win_amd64.whl", hash = "sha256:c906dd42daadd14b4506a2e6c62dfef3d4919b5953d32ae5e5f0d99efd103c89", size = 330648, upload-time = "2026-01-02T13:55:17.421Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -70,6 +108,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, ] +[[package]] +name = "async-timeout" +version = "5.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274, upload-time = "2024-11-06T16:41:39.6Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233, upload-time = "2024-11-06T16:41:37.9Z" }, +] + [[package]] name = "attrs" version = "25.4.0"