Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions reboot/examples/agent-wiki/api/agent_wiki/v1/wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ class WikiState(Model):
# entries and folds each transcript's material into the
# wiki's markdown `content` and into Pages it references.
transcripts: dict[str, bool] = Field(tag=5, default_factory=dict)
# The user ID of the user who owns this Wiki.
owner_id: str = Field(tag=6)


class WikiCreateRequest(Model):
name: str = Field(tag=1)
description: str = Field(tag=2)
owner_id: str = Field(tag=3)


class WikiGetResponse(Model):
Expand Down Expand Up @@ -113,11 +116,14 @@ class PageState(Model):
# link, call the corresponding type with the referenced
# state ID (e.g., `Page.get` on `abc123`).
content: str = Field(tag=2, default="")
# The user ID of the user who owns this Page.
owner_id: str = Field(tag=3)


class PageCreateRequest(Model):
title: str = Field(tag=1)
content: str = Field(tag=2)
owner_id: str = Field(tag=3)


class PageGetResponse(Model):
Expand All @@ -135,10 +141,13 @@ class PageUpdateRequest(Model):

class TranscriptState(Model):
messages: list[TranscriptMessage] = Field(tag=1, default_factory=list)
# The user ID of the user who owns this Transcript.
owner_id: str = Field(tag=2)


class TranscriptCreateRequest(Model):
messages: list[TranscriptMessage] = Field(tag=1, default_factory=list)
owner_id: str = Field(tag=2)


class TranscriptGetResponse(Model):
Expand Down
44 changes: 43 additions & 1 deletion reboot/examples/agent-wiki/backend/src/servicers/wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,35 @@
from agent_wiki.v1.wiki_rbt import Page, Transcript, User, Wiki
from dataclasses import dataclass
from pydantic_ai import RunContext
from rbt.v1alpha1.errors_pb2 import Ok, PermissionDenied, Unauthenticated
from reboot.agents.pydantic_ai import Agent
from reboot.aio.auth.authorizers import allow_if, is_app_internal
from reboot.aio.contexts import (
ReaderContext,
TransactionContext,
WorkflowContext,
WriterContext,
)
from reboot.aio.workflows import at_most_once, until
from typing import Union

logger = logging.getLogger(__name__)


def _caller_is_owner(
*,
context: ReaderContext,
state: Union[Wiki.State, Transcript.State, Page.State],
**kwargs,
):
"""Allow when the caller's `user_id` matches the stored owner ID."""
if context.auth is None or context.auth.user_id is None:
return Unauthenticated()
if state is not None and context.auth.user_id == state.owner_id:
return Ok()
return PermissionDenied()


def _truncate(value: object, limit: int = 500) -> str:
"""Render `value` as a string, shortened for log lines so a
big tool payload doesn't flood the output."""
Expand Down Expand Up @@ -79,6 +96,7 @@ def _log_librarian_node(prefix: str, node: object) -> None:
@dataclass
class LibrarianDeps:
wiki_id: str
owner_id: str


librarian = Agent(
Expand Down Expand Up @@ -238,6 +256,7 @@ async def create_page(
context,
title=title,
content=content,
owner_id=run_context.deps.owner_id,
)
return page.state_id

Expand All @@ -253,10 +272,15 @@ async def create_wiki(
) -> UserCreateWikiResponse:
"""Create a new Wiki, record it on the user under
the given name, and kick off its ingest workflow."""
owner_id = (
context.auth.user_id if context.auth is not None and
context.auth.user_id is not None else ""
)
wiki, _ = await Wiki.create(
context,
name=request.name,
description=request.description,
owner_id=owner_id,
)
self.state.wikis[request.name] = wiki.state_id
return UserCreateWikiResponse(wiki_id=wiki.state_id)
Expand Down Expand Up @@ -284,13 +308,17 @@ class WikiServicer(Wiki.Servicer):
title and description metadata, and the librarian
`ingest` workflow that folds transcripts into it."""

def authorizer(self):
return allow_if(any=[_caller_is_owner, is_app_internal])

async def create(
self,
context: WriterContext,
request: Wiki.CreateRequest,
) -> None:
self.state.name = request.name
self.state.description = request.description
self.state.owner_id = request.owner_id
await self.ref().schedule().ingest(context)

async def get(
Expand Down Expand Up @@ -320,6 +348,7 @@ async def add_transcript(
transcript, _ = await Transcript.create(
context,
messages=list(request.messages),
owner_id=self.state.owner_id,
)
self.state.transcripts[transcript.state_id] = False
return WikiAddTranscriptResponse(
Expand All @@ -337,6 +366,8 @@ async def ingest(
Pages, then mark the transcript ingested."""
wiki = Wiki.ref()
wiki_id = wiki.state_id
initial_state = await wiki.read(context)
owner_id = initial_state.owner_id

async for _ in context.loop("Ingest loop"):

Expand Down Expand Up @@ -375,7 +406,10 @@ async def run_librarian() -> str:
async with librarian.iter(
context,
prompt,
deps=LibrarianDeps(wiki_id=wiki_id),
deps=LibrarianDeps(
wiki_id=wiki_id,
owner_id=owner_id,
),
) as run:
async for node in run:
_log_librarian_node(log_prefix, node)
Expand Down Expand Up @@ -410,13 +444,17 @@ class PageServicer(Page.Servicer):
"""Servicer for an individual Page: a markdown body with
a title."""

def authorizer(self):
return allow_if(any=[_caller_is_owner, is_app_internal])

async def create(
self,
context: WriterContext,
request: Page.CreateRequest,
) -> None:
self.state.title = request.title
self.state.content = request.content
self.state.owner_id = request.owner_id

async def get(
self,
Expand All @@ -440,12 +478,16 @@ class TranscriptServicer(Transcript.Servicer):
"""Servicer for an individual Transcript (raw conversation
transcript)."""

def authorizer(self):
return allow_if(any=[_caller_is_owner, is_app_internal])

async def create(
self,
context: WriterContext,
request: Transcript.CreateRequest,
) -> None:
self.state.messages = list(request.messages)
self.state.owner_id = request.owner_id

async def get(
self,
Expand Down
Loading