diff --git a/reboot/examples/agent-wiki/api/agent_wiki/v1/wiki.py b/reboot/examples/agent-wiki/api/agent_wiki/v1/wiki.py index d3c974c..262114d 100644 --- a/reboot/examples/agent-wiki/api/agent_wiki/v1/wiki.py +++ b/reboot/examples/agent-wiki/api/agent_wiki/v1/wiki.py @@ -65,11 +65,14 @@ class WikiState(Model): # entries and folds each transcript's material into the # wiki's markdown `content` and into Pages it references. transcripts: dict[str, bool] = Field(tag=5, default_factory=dict) + # The user ID of the user who owns this Wiki. + owner_id: str = Field(tag=6) class WikiCreateRequest(Model): name: str = Field(tag=1) description: str = Field(tag=2) + owner_id: str = Field(tag=3) class WikiGetResponse(Model): @@ -113,11 +116,14 @@ class PageState(Model): # link, call the corresponding type with the referenced # state ID (e.g., `Page.get` on `abc123`). content: str = Field(tag=2, default="") + # The user ID of the user who owns this Page. + owner_id: str = Field(tag=3) class PageCreateRequest(Model): title: str = Field(tag=1) content: str = Field(tag=2) + owner_id: str = Field(tag=3) class PageGetResponse(Model): @@ -135,10 +141,13 @@ class PageUpdateRequest(Model): class TranscriptState(Model): messages: list[TranscriptMessage] = Field(tag=1, default_factory=list) + # The user ID of the user who owns this Transcript. + owner_id: str = Field(tag=2) class TranscriptCreateRequest(Model): messages: list[TranscriptMessage] = Field(tag=1, default_factory=list) + owner_id: str = Field(tag=2) class TranscriptGetResponse(Model): diff --git a/reboot/examples/agent-wiki/backend/src/servicers/wiki.py b/reboot/examples/agent-wiki/backend/src/servicers/wiki.py index 2256525..bde3173 100644 --- a/reboot/examples/agent-wiki/backend/src/servicers/wiki.py +++ b/reboot/examples/agent-wiki/backend/src/servicers/wiki.py @@ -12,7 +12,9 @@ from agent_wiki.v1.wiki_rbt import Page, Transcript, User, Wiki from dataclasses import dataclass from pydantic_ai import RunContext +from rbt.v1alpha1.errors_pb2 import Ok, PermissionDenied, Unauthenticated from reboot.agents.pydantic_ai import Agent +from reboot.aio.auth.authorizers import allow_if, is_app_internal from reboot.aio.contexts import ( ReaderContext, TransactionContext, @@ -20,10 +22,25 @@ WriterContext, ) from reboot.aio.workflows import at_most_once, until +from typing import Union logger = logging.getLogger(__name__) +def _caller_is_owner( + *, + context: ReaderContext, + state: Union[Wiki.State, Transcript.State, Page.State], + **kwargs, +): + """Allow when the caller's `user_id` matches the stored owner ID.""" + if context.auth is None or context.auth.user_id is None: + return Unauthenticated() + if state is not None and context.auth.user_id == state.owner_id: + return Ok() + return PermissionDenied() + + def _truncate(value: object, limit: int = 500) -> str: """Render `value` as a string, shortened for log lines so a big tool payload doesn't flood the output.""" @@ -79,6 +96,7 @@ def _log_librarian_node(prefix: str, node: object) -> None: @dataclass class LibrarianDeps: wiki_id: str + owner_id: str librarian = Agent( @@ -238,6 +256,7 @@ async def create_page( context, title=title, content=content, + owner_id=run_context.deps.owner_id, ) return page.state_id @@ -253,10 +272,15 @@ async def create_wiki( ) -> UserCreateWikiResponse: """Create a new Wiki, record it on the user under the given name, and kick off its ingest workflow.""" + owner_id = ( + context.auth.user_id if context.auth is not None and + context.auth.user_id is not None else "" + ) wiki, _ = await Wiki.create( context, name=request.name, description=request.description, + owner_id=owner_id, ) self.state.wikis[request.name] = wiki.state_id return UserCreateWikiResponse(wiki_id=wiki.state_id) @@ -284,6 +308,9 @@ class WikiServicer(Wiki.Servicer): title and description metadata, and the librarian `ingest` workflow that folds transcripts into it.""" + def authorizer(self): + return allow_if(any=[_caller_is_owner, is_app_internal]) + async def create( self, context: WriterContext, @@ -291,6 +318,7 @@ async def create( ) -> None: self.state.name = request.name self.state.description = request.description + self.state.owner_id = request.owner_id await self.ref().schedule().ingest(context) async def get( @@ -320,6 +348,7 @@ async def add_transcript( transcript, _ = await Transcript.create( context, messages=list(request.messages), + owner_id=self.state.owner_id, ) self.state.transcripts[transcript.state_id] = False return WikiAddTranscriptResponse( @@ -337,6 +366,8 @@ async def ingest( Pages, then mark the transcript ingested.""" wiki = Wiki.ref() wiki_id = wiki.state_id + initial_state = await wiki.read(context) + owner_id = initial_state.owner_id async for _ in context.loop("Ingest loop"): @@ -375,7 +406,10 @@ async def run_librarian() -> str: async with librarian.iter( context, prompt, - deps=LibrarianDeps(wiki_id=wiki_id), + deps=LibrarianDeps( + wiki_id=wiki_id, + owner_id=owner_id, + ), ) as run: async for node in run: _log_librarian_node(log_prefix, node) @@ -410,6 +444,9 @@ class PageServicer(Page.Servicer): """Servicer for an individual Page: a markdown body with a title.""" + def authorizer(self): + return allow_if(any=[_caller_is_owner, is_app_internal]) + async def create( self, context: WriterContext, @@ -417,6 +454,7 @@ async def create( ) -> None: self.state.title = request.title self.state.content = request.content + self.state.owner_id = request.owner_id async def get( self, @@ -440,12 +478,16 @@ class TranscriptServicer(Transcript.Servicer): """Servicer for an individual Transcript (raw conversation transcript).""" + def authorizer(self): + return allow_if(any=[_caller_is_owner, is_app_internal]) + async def create( self, context: WriterContext, request: Transcript.CreateRequest, ) -> None: self.state.messages = list(request.messages) + self.state.owner_id = request.owner_id async def get( self,