Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions src/google/adk/artifacts/file_artifact_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,18 +217,20 @@ class FileArtifactService(BaseArtifactService):

# Storage layout matches the cloud and in-memory services:
# root/
# └── users/
# └── {user_id}/
# ├── sessions/
# │ └── {session_id}/
# │ └── artifacts/
# │ └── {artifact_path}/ # derived from filename
# │ └── versions/
# │ └── {version}/
# │ ├── {original_filename}
# │ └── metadata.json
# └── artifacts/
# └── {artifact_path}/...
# └── apps/
# └── {app_name}/
# └── users/
# └── {user_id}/
# ├── sessions/
# │ └── {session_id}/
# │ └── artifacts/
# │ └── {artifact_path}/ # derived from filename
# │ └── versions/
# │ └── {version}/
# │ ├── {original_filename}
# │ └── metadata.json
# └── artifacts/
# └── {artifact_path}/...
#
# Artifact paths are derived from the provided filenames: separators create
# nested directories, and path traversal is rejected to keep the layout
Expand All @@ -244,19 +246,21 @@ def __init__(self, root_dir: Path | str):
self.root_dir = Path(root_dir).expanduser().resolve()
self.root_dir.mkdir(parents=True, exist_ok=True)

def _base_root(self, user_id: str, /) -> Path:
def _base_root(self, app_name: str, user_id: str, /) -> Path:
"""Returns the artifacts root directory for a user."""
_validate_path_segment(app_name, "app_name")
_validate_path_segment(user_id, "user_id")
return self.root_dir / "users" / user_id
return self.root_dir / "apps" / app_name / "users" / user_id

def _scope_root(
self,
app_name: str,
user_id: str,
session_id: Optional[str],
filename: str,
) -> Path:
"""Returns the directory that represents the artifact scope."""
base = self._base_root(user_id)
base = self._base_root(app_name, user_id)
if _is_user_scoped(session_id, filename):
return _user_artifacts_dir(base)
if not session_id:
Expand All @@ -267,12 +271,14 @@ def _scope_root(

def _artifact_dir(
self,
app_name: str,
user_id: str,
session_id: Optional[str],
filename: str,
) -> Path:
"""Builds the directory path for an artifact."""
scope_root = self._scope_root(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand All @@ -283,6 +289,7 @@ def _artifact_dir(
def _build_artifact_version(
self,
*,
app_name: str,
user_id: str,
session_id: Optional[str],
filename: str,
Expand All @@ -294,6 +301,7 @@ def _build_artifact_version(
metadata.canonical_uri
if metadata and metadata.canonical_uri
else self._canonical_uri(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand All @@ -312,13 +320,15 @@ def _build_artifact_version(
def _canonical_uri(
self,
*,
app_name: str,
user_id: str,
session_id: Optional[str],
filename: str,
version: int,
) -> str:
"""Builds the canonical file:// URI for an artifact payload."""
artifact_dir = self._artifact_dir(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand Down Expand Up @@ -357,6 +367,7 @@ async def save_artifact(
"""
return await asyncio.to_thread(
self._save_artifact_sync,
app_name,
user_id,
filename,
artifact,
Expand All @@ -366,6 +377,7 @@ async def save_artifact(

def _save_artifact_sync(
self,
app_name: str,
user_id: str,
filename: str,
artifact: Union[types.Part, dict[str, Any]],
Expand All @@ -375,6 +387,7 @@ def _save_artifact_sync(
"""Saves an artifact to disk and returns its version."""
artifact = ensure_part(artifact)
artifact_dir = self._artifact_dir(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand Down Expand Up @@ -407,6 +420,7 @@ def _save_artifact_sync(
)

canonical_uri = self._canonical_uri(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand Down Expand Up @@ -441,6 +455,7 @@ async def load_artifact(
) -> Optional[types.Part]:
return await asyncio.to_thread(
self._load_artifact_sync,
app_name,
user_id,
filename,
session_id,
Expand All @@ -449,13 +464,15 @@ async def load_artifact(

def _load_artifact_sync(
self,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str],
version: Optional[int],
) -> Optional[types.Part]:
"""Loads an artifact from disk."""
artifact_dir = self._artifact_dir(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand Down Expand Up @@ -510,19 +527,21 @@ async def list_artifact_keys(
) -> list[str]:
return await asyncio.to_thread(
self._list_artifact_keys_sync,
app_name,
user_id,
session_id,
)

def _list_artifact_keys_sync(
self,
app_name: str,
user_id: str,
session_id: Optional[str],
) -> list[str]:
"""Lists artifact filenames for the given session/user."""
filenames: set[str] = set()

base_root = self._base_root(user_id)
base_root = self._base_root(app_name, user_id)

if session_id:
session_root = _session_artifacts_dir(base_root, session_id)
Expand Down Expand Up @@ -565,18 +584,21 @@ async def delete_artifact(
"""
await asyncio.to_thread(
self._delete_artifact_sync,
app_name,
user_id,
filename,
session_id,
)

def _delete_artifact_sync(
self,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str],
) -> None:
artifact_dir = self._artifact_dir(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand All @@ -597,18 +619,21 @@ async def list_versions(
"""Lists all versions stored for an artifact."""
return await asyncio.to_thread(
self._list_versions_sync,
app_name,
user_id,
filename,
session_id,
)

def _list_versions_sync(
self,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str],
) -> list[int]:
artifact_dir = self._artifact_dir(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand All @@ -627,18 +652,21 @@ async def list_artifact_versions(
"""Lists metadata for each artifact version on disk."""
return await asyncio.to_thread(
self._list_artifact_versions_sync,
app_name,
user_id,
filename,
session_id,
)

def _list_artifact_versions_sync(
self,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str],
) -> list[ArtifactVersion]:
artifact_dir = self._artifact_dir(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand All @@ -650,6 +678,7 @@ def _list_artifact_versions_sync(
metadata = _read_metadata(metadata_path)
artifact_versions.append(
self._build_artifact_version(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand All @@ -672,6 +701,7 @@ async def get_artifact_version(
"""Gets metadata for a specific artifact version."""
return await asyncio.to_thread(
self._get_artifact_version_sync,
app_name,
user_id,
filename,
session_id,
Expand All @@ -680,12 +710,14 @@ async def get_artifact_version(

def _get_artifact_version_sync(
self,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str],
version: Optional[int],
) -> Optional[ArtifactVersion]:
artifact_dir = self._artifact_dir(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand All @@ -703,6 +735,7 @@ def _get_artifact_version_sync(
metadata_path = _metadata_path(artifact_dir, version_to_read)
metadata = _read_metadata(metadata_path)
return self._build_artifact_version(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=filename,
Expand Down
Loading