diff --git a/DashAI/alembic/versions/3db684f4090a_merge_datafile_and_dataset_heads.py b/DashAI/alembic/versions/3db684f4090a_merge_datafile_and_dataset_heads.py new file mode 100644 index 000000000..936f5ed0f --- /dev/null +++ b/DashAI/alembic/versions/3db684f4090a_merge_datafile_and_dataset_heads.py @@ -0,0 +1,26 @@ +"""merge datafile and dataset heads + +Revision ID: 3db684f4090a +Revises: a1f8e3b0c2d9, c3d7a1f05e8b +Create Date: 2026-05-27 15:49:47.570864 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '3db684f4090a' +down_revision: Union[str, None] = ('a1f8e3b0c2d9', 'c3d7a1f05e8b') +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/DashAI/alembic/versions/a1c3e5f7b9d2_add_hub_download_table.py b/DashAI/alembic/versions/a1c3e5f7b9d2_add_hub_download_table.py new file mode 100644 index 000000000..e3d20d670 --- /dev/null +++ b/DashAI/alembic/versions/a1c3e5f7b9d2_add_hub_download_table.py @@ -0,0 +1,47 @@ +"""Add datafile table + +Revision ID: a1c3e5f7b9d2 +Revises: b4f9e70098e7 +Create Date: 2026-05-08 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a1c3e5f7b9d2" +down_revision: Union[str, None] = "b4f9e70098e7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "datafile", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("source_name", sa.String(), nullable=False), + sa.Column("dataset_id", sa.String(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("local_path", sa.String(), nullable=True), + sa.Column( + "status", + sa.Enum("downloading", "ready", "error", name="hubdownloadstatus"), + nullable=False, + ), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("last_modified", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_datafile")), + sa.UniqueConstraint( + "source_name", + "dataset_id", + name="uq_datafile_source_dataset", + ), + ) + + +def downgrade() -> None: + op.drop_table("datafile") diff --git a/DashAI/alembic/versions/a1f8e3b0c2d9_add_total_rows_columns_to_dataset.py b/DashAI/alembic/versions/a1f8e3b0c2d9_add_total_rows_columns_to_dataset.py new file mode 100644 index 000000000..f91636fb1 --- /dev/null +++ b/DashAI/alembic/versions/a1f8e3b0c2d9_add_total_rows_columns_to_dataset.py @@ -0,0 +1,27 @@ +"""Add total_rows and total_columns to dataset + +Revision ID: a1f8e3b0c2d9 +Revises: b4f9e70098e7 +Create Date: 2026-05-14 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +revision: str = "a1f8e3b0c2d9" +down_revision: Union[str, None] = "b4f9e70098e7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("dataset", sa.Column("total_rows", sa.Integer(), nullable=True)) + op.add_column("dataset", sa.Column("total_columns", sa.Integer(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("dataset", "total_columns") + op.drop_column("dataset", "total_rows") diff --git a/DashAI/alembic/versions/c3d7a1f05e8b_add_metadata_to_datafile.py b/DashAI/alembic/versions/c3d7a1f05e8b_add_metadata_to_datafile.py new file mode 100644 index 000000000..fea1ca339 --- /dev/null +++ b/DashAI/alembic/versions/c3d7a1f05e8b_add_metadata_to_datafile.py @@ -0,0 +1,32 @@ +"""Add metadata columns to datafile table + +Revision ID: c3d7a1f05e8b +Revises: a1c3e5f7b9d2 +Create Date: 2026-05-12 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c3d7a1f05e8b" +down_revision: Union[str, None] = "a1c3e5f7b9d2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("datafile", sa.Column("size_bytes", sa.BigInteger(), nullable=True)) + op.add_column("datafile", sa.Column("description", sa.Text(), nullable=True)) + op.add_column("datafile", sa.Column("tags", sa.Text(), nullable=True)) + op.add_column("datafile", sa.Column("source_url", sa.Text(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("datafile", "source_url") + op.drop_column("datafile", "tags") + op.drop_column("datafile", "description") + op.drop_column("datafile", "size_bytes") diff --git a/DashAI/back/api/api_v1/api.py b/DashAI/back/api/api_v1/api.py index 6a3476db2..d3fd89fb9 100644 --- a/DashAI/back/api/api_v1/api.py +++ b/DashAI/back/api/api_v1/api.py @@ -2,6 +2,8 @@ from DashAI.back.api.api_v1.endpoints.components import router as components from DashAI.back.api.api_v1.endpoints.converters import router as converters +from DashAI.back.api.api_v1.endpoints.datafile import router as datafile_router +from DashAI.back.api.api_v1.endpoints.dataset_source import router as dataset_source from DashAI.back.api.api_v1.endpoints.datasets import router as datasets from DashAI.back.api.api_v1.endpoints.explainers import router as explainers from DashAI.back.api.api_v1.endpoints.explorers import router as explorers @@ -40,3 +42,5 @@ api_router_v1.include_router(metrics, prefix="/metrics") api_router_v1.include_router(hardware, prefix="/hardware") api_router_v1.include_router(scoring, prefix="/scoring") +api_router_v1.include_router(dataset_source, prefix="/dataset-source") +api_router_v1.include_router(datafile_router, prefix="/datafile") diff --git a/DashAI/back/api/api_v1/endpoints/components.py b/DashAI/back/api/api_v1/endpoints/components.py index ab701e2e9..3c2d5cc53 100644 --- a/DashAI/back/api/api_v1/endpoints/components.py +++ b/DashAI/back/api/api_v1/endpoints/components.py @@ -353,10 +353,14 @@ async def get_component_image( .get("metadata", {}) .get("image_preview", None) ) + cache_headers = {"Cache-Control": "public, max-age=3600"} + if not image_path: with open(local_path_image / "placeholder.svg", "rb") as image_file: return StreamingResponse( - io.BytesIO(image_file.read()), media_type="image/svg+xml" + io.BytesIO(image_file.read()), + media_type="image/svg+xml", + headers=cache_headers, ) # If it is a URL, we obtain the image from the URL @@ -364,22 +368,30 @@ async def get_component_image( response = requests.get(image_path, timeout=5) if response.status_code == 200: return StreamingResponse( - io.BytesIO(response.content), media_type="image/png" + io.BytesIO(response.content), + media_type="image/png", + headers=cache_headers, ) else: with open(local_path_image / "placeholder.svg", "rb") as image_file: return StreamingResponse( - io.BytesIO(image_file.read()), media_type="image/svg+xml" + io.BytesIO(image_file.read()), + media_type="image/svg+xml", + headers=cache_headers, ) # Otherwise, we assume it is a local path try: with open(local_path_image / image_path, "rb") as image_file: return StreamingResponse( - io.BytesIO(image_file.read()), media_type="image/png" + io.BytesIO(image_file.read()), + media_type="image/png", + headers=cache_headers, ) except FileNotFoundError: with open(local_path_image / "placeholder.svg", "rb") as image_file: return StreamingResponse( - io.BytesIO(image_file.read()), media_type="image/svg+xml" + io.BytesIO(image_file.read()), + media_type="image/svg+xml", + headers=cache_headers, ) diff --git a/DashAI/back/api/api_v1/endpoints/converters.py b/DashAI/back/api/api_v1/endpoints/converters.py index db5563817..7d1507c60 100644 --- a/DashAI/back/api/api_v1/endpoints/converters.py +++ b/DashAI/back/api/api_v1/endpoints/converters.py @@ -229,15 +229,23 @@ async def delete_converter( notebook.file_path, dirs_exist_ok=True, ) + # copytree preserves the source mtime, which may predate the API + # cache entry. Touch data.arrow so the cache invalidation detects + # the change even when no previous converters are re-run. + import os + import time + + arrow_path = os.path.join(notebook.file_path, "dataset", "data.arrow") + if os.path.exists(arrow_path): + now = time.time() + os.utime(arrow_path, (now, now)) # Enqueue all previous converters job_ids = [] for converter in previous_converters: - # Crear instancia de ConverterJob y encolarlo directamente job = ConverterJob(converter_id=converter.id) - job_queue.put(job) - if hasattr(job, "id"): - job_ids.append(job.id) + enqueued = job_queue.put(job) + job_ids.append(enqueued.id) # Delete all the converters after the current one for converter in next_converters: diff --git a/DashAI/back/api/api_v1/endpoints/datafile.py b/DashAI/back/api/api_v1/endpoints/datafile.py new file mode 100644 index 000000000..b71ebf4de --- /dev/null +++ b/DashAI/back/api/api_v1/endpoints/datafile.py @@ -0,0 +1,217 @@ +"""Datafile management endpoints.""" + +import json +import logging +import os +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List + +from fastapi import APIRouter, Depends, status +from fastapi.exceptions import HTTPException +from kink import di +from pydantic import BaseModel +from sqlalchemy import exc + +from DashAI.back.core.enums.status import DatafileStatus +from DashAI.back.dependencies.database.models import Datafile + +if TYPE_CHECKING: + from sqlalchemy.orm import sessionmaker + + from DashAI.back.dependencies.registry import ComponentRegistry + +log = logging.getLogger(__name__) +router = APIRouter() + + +def _row_to_dict(row: Datafile) -> Dict[str, Any]: + return { + "id": row.id, + "source_name": row.source_name, + "dataset_id": row.dataset_id, + "name": row.name, + "local_path": row.local_path, + "status": row.status.value, + "error_message": row.error_message, + "size_bytes": row.size_bytes, + "description": row.description, + "tags": json.loads(row.tags) if row.tags else [], + "source_url": row.source_url, + "created": row.created.isoformat() if row.created else None, + "last_modified": row.last_modified.isoformat() if row.last_modified else None, + } + + +class CreateDownloadRequest(BaseModel): + source_name: str + dataset_id: str + name: str + description: str = "" + tags: list[str] = [] + source_url: str = "" + + +@router.get("/", response_model=List[Dict[str, Any]]) +async def list_downloads( + session_factory: "sessionmaker" = Depends(lambda: di["session_factory"]), +) -> List[Dict[str, Any]]: + """Return all datafile records.""" + with session_factory() as db: + rows = db.query(Datafile).order_by(Datafile.created.asc()).all() + return [_row_to_dict(r) for r in rows] + + +@router.post("/", status_code=status.HTTP_201_CREATED, response_model=Dict[str, Any]) +async def create_download( + body: CreateDownloadRequest, + session_factory: "sessionmaker" = Depends(lambda: di["session_factory"]), + registry: "ComponentRegistry" = Depends(lambda: di["component_registry"]), +) -> Dict[str, Any]: + """Create a Datafile record. + + If a record for (source_name, dataset_id) already exists and its status is + READY or DOWNLOADING, it is returned immediately. If status is ERROR the + record is reset to DOWNLOADING so the caller can re-enqueue the job. + """ + sources = registry._registry.get("DatasetSource", {}) + if body.source_name not in sources: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"DatasetSource '{body.source_name}' not found.", + ) + + with session_factory() as db: + existing = ( + db.query(Datafile) + .filter( + Datafile.source_name == body.source_name, + Datafile.dataset_id == body.dataset_id, + ) + .first() + ) + if existing is not None: + if existing.status == DatafileStatus.READY: + return _row_to_dict(existing) + if existing.status == DatafileStatus.DOWNLOADING: + return _row_to_dict(existing) + # ERROR — allow retry: reset to downloading + existing.status = DatafileStatus.DOWNLOADING + existing.error_message = None + existing.local_path = None + existing.name = body.name + existing.description = body.description + existing.tags = json.dumps(body.tags) + existing.source_url = body.source_url or None + try: + db.commit() + db.refresh(existing) + except exc.SQLAlchemyError as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="DB error resetting download.", + ) from e + row = existing + else: + row = Datafile( + source_name=body.source_name, + dataset_id=body.dataset_id, + name=body.name, + description=body.description, + tags=json.dumps(body.tags), + source_url=body.source_url or None, + status=DatafileStatus.DOWNLOADING, + ) + db.add(row) + try: + db.commit() + db.refresh(row) + except exc.SQLAlchemyError as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="DB error creating download record.", + ) from e + + return _row_to_dict(row) + + +@router.get("/{datafile_id}", response_model=Dict[str, Any]) +async def get_download( + datafile_id: int, + session_factory: "sessionmaker" = Depends(lambda: di["session_factory"]), +) -> Dict[str, Any]: + """Return a single datafile record by id.""" + with session_factory() as db: + row = db.get(Datafile, datafile_id) + if row is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Datafile {datafile_id} not found.", + ) + return _row_to_dict(row) + + +@router.delete("/{datafile_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_download( + datafile_id: int, + session_factory: "sessionmaker" = Depends(lambda: di["session_factory"]), +) -> None: + """Delete a datafile record and its cached files.""" + import shutil + + with session_factory() as db: + row = db.get(Datafile, datafile_id) + if row is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Datafile {datafile_id} not found.", + ) + local_path = row.local_path + try: + db.delete(row) + db.commit() + except exc.SQLAlchemyError as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="DB error deleting download record.", + ) from e + + if local_path and os.path.exists(local_path): + shutil.rmtree(local_path, ignore_errors=True) + + +@router.get("/{datafile_id}/files", response_model=List[str]) +async def list_files( + datafile_id: int, + session_factory: "sessionmaker" = Depends(lambda: di["session_factory"]), +) -> List[str]: + """Return the list of files in a ready datafile directory.""" + with session_factory() as db: + row = db.get(Datafile, datafile_id) + if row is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Datafile {datafile_id} not found.", + ) + if row.status != DatafileStatus.READY or not row.local_path: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Download is not ready yet.", + ) + local_path = row.local_path + + path = Path(local_path) + if not path.exists(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Download directory not found on disk.", + ) + files = sorted( + str(p.relative_to(path)) + for p in path.rglob("*") + if p.is_file() + and not any(part.startswith(".") for part in p.relative_to(path).parts) + ) + return files diff --git a/DashAI/back/api/api_v1/endpoints/dataset_source.py b/DashAI/back/api/api_v1/endpoints/dataset_source.py new file mode 100644 index 000000000..15c98b36a --- /dev/null +++ b/DashAI/back/api/api_v1/endpoints/dataset_source.py @@ -0,0 +1,348 @@ +"""Dataset source API endpoints.""" + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict +from urllib.parse import unquote + +from fastapi import APIRouter, Depends, Query, status +from fastapi.exceptions import HTTPException +from kink import di +from pydantic import BaseModel + +from DashAI.back.types.inf.type_inference import infer_types + +if TYPE_CHECKING: + from DashAI.back.dependencies.registry import ComponentRegistry + +log = logging.getLogger(__name__) +router = APIRouter() + + +def _get_source(source_name: str, registry: "ComponentRegistry"): + """Retrieve and instantiate a DatasetSource from the registry. + + Parameters + ---------- + source_name : str + Registered class name of the DatasetSource. + registry : ComponentRegistry + The component registry to look up. + + Returns + ------- + BaseDatasetSource + Instantiated source object. + + Raises + ------ + HTTPException + 404 if source_name is not found in the DatasetSource registry. + """ + sources = registry._registry.get("DatasetSource", {}) + if source_name not in sources: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"DatasetSource '{source_name}' not found.", + ) + return sources[source_name]["class"]() + + +@router.get("/{source_name}/search") +async def search_datasets( + source_name: str, + q: str = Query(default="", description="Search query"), + limit: int = Query(default=20, ge=1, le=100), + cursor: str = Query(default="", description="Pagination cursor from previous page"), + registry: "ComponentRegistry" = Depends(lambda: di["component_registry"]), +) -> Dict[str, Any]: + """Search for datasets in a registered source. + + Parameters + ---------- + source_name : str + Registered DatasetSource class name. + q : str + Search query string. + limit : int + Maximum number of results (1-100). + cursor : str + Opaque pagination token returned by the previous call. Empty string + means first page. + registry : ComponentRegistry + Injected component registry. + + Returns + ------- + dict + ``{"results": [...], "next_cursor": str | null}`` + """ + source = _get_source(source_name, registry) + page = source.search(q, limit=limit, cursor=cursor or None) + return { + "results": [ + { + "id": e.id, + "name": e.name, + "description": e.description, + "tags": e.tags, + "size_bytes": e.size_bytes, + "url": e.url, + "source": e.source, + } + for e in page.entries + ], + "next_cursor": page.next_cursor, + } + + +@router.get("/{source_name}/{dataset_id:path}/info") +async def get_dataset_info( + source_name: str, + dataset_id: str, + registry: "ComponentRegistry" = Depends(lambda: di["component_registry"]), +) -> Dict[str, Any]: + """Return full metadata for a single dataset (description, tags, etc.). + + Parameters + ---------- + source_name : str + Registered DatasetSource class name. + dataset_id : str + Source-specific dataset identifier (URL-encoded). + registry : ComponentRegistry + Injected component registry. + + Returns + ------- + dict + DatasetEntry fields, or empty dict if the source has no enrichment. + """ + source = _get_source(source_name, registry) + decoded_id = unquote(dataset_id) + entry = source.get_info(decoded_id) + if entry is None: + return {} + return { + "id": entry.id, + "description": entry.description, + "tags": entry.tags, + "size_bytes": entry.size_bytes, + } + + +class PreviewRequest(BaseModel): + """Request body for previewing a dataset with dataloader params. + + Parameters + ---------- + dataloader : str | None + Name of the DataLoader to use for parsing the file. + params : dict + DataLoader parameters (e.g., separator for CSV). + n_rows : int + Number of rows to sample (1-500). + datafile_id : int | None + If set, use this pre-downloaded local file instead of fetching from source. + selected_file : str | None + Relative filename inside the datafile directory. + """ + + dataloader: str | None = None + params: Dict[str, Any] = {} + n_rows: int = 100 + datafile_id: int | None = None + selected_file: str | None = None + + +@router.post("/{source_name}/{dataset_id:path}/preview") +async def preview_dataset_with_params( + source_name: str, + dataset_id: str, + body: PreviewRequest, + registry: "ComponentRegistry" = Depends(lambda: di["component_registry"]), + session_factory=Depends(lambda: di["session_factory"]), +) -> Dict[str, Any]: + """Fetch a sample preview using a DataLoader and params. + + If ``hub_download_id`` is provided the already-downloaded local file is + used directly — no re-download from the source occurs. + + Parameters + ---------- + source_name : str + Registered DatasetSource class name. + dataset_id : str + Source-specific dataset identifier (URL-encoded). + body : PreviewRequest + DataLoader name, params, row count, and optional datafile_id. + registry : ComponentRegistry + Injected component registry. + session_factory + Injected DB session factory (used when datafile_id is set). + + Returns + ------- + dict + ``{"sample": [...], "inferred_types": {...}, "preview_row_count": int}``. + """ + _get_source(source_name, registry) # validate source exists + decoded_id = unquote(dataset_id) + n_rows = max(1, min(body.n_rows, 500)) + + try: + if body.datafile_id is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="datafile_id is required.", + ) + + from DashAI.back.core.enums.status import DatafileStatus + from DashAI.back.dependencies.database.models import Datafile + + with session_factory() as db: + hub_row = db.get(Datafile, body.datafile_id) + if hub_row is None or hub_row.status != DatafileStatus.READY: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Hub download not ready or not found.", + ) + if body.selected_file: + file_path = str(Path(hub_row.local_path) / body.selected_file) + else: + base_path = Path(hub_row.local_path) + files = sorted( + str(p) + for p in base_path.rglob("*") + if p.is_file() + and not any( + part.startswith(".") for part in p.relative_to(base_path).parts + ) + ) + if not files: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No files found in hub download directory.", + ) + file_path = files[0] + work_dir = str(Path(file_path).parent) + dataloader_name = body.dataloader + if not dataloader_name: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="dataloader is required.", + ) + + dl_registry = registry._registry.get("DataLoader", {}) + if dataloader_name not in dl_registry: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"DataLoader '{dataloader_name}' not found.", + ) + + dataloader = dl_registry[dataloader_name]["class"]() + params = body.params or {} + + try: + preview_df = dataloader.load_preview( + filepath_or_buffer=file_path, + params=params, + n_rows=n_rows, + ) + except NotImplementedError: + dataset = dataloader.load_data( + filepath_or_buffer=file_path, + temp_path=work_dir, + params=params, + n_sample=n_rows, + ) + preview_df = dataset.to_pandas().head(n_rows) + + except HTTPException: + raise + except Exception as exc: + log.exception("Error fetching preview for %s/%s", source_name, decoded_id) + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Failed to fetch preview from source: {exc}", + ) from exc + + if preview_df.empty: + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"Source returned no data for dataset '{decoded_id}'.", + ) + + inferred = infer_types(preview_df, method="DashAIPtype") + sample = preview_df.to_dict(orient="records") + + return { + "sample": sample, + "inferred_types": inferred, + "preview_row_count": len(preview_df), + } + + +class ImportRequest(BaseModel): + """Request body for the dataset import endpoint. + + Parameters + ---------- + dataset_id : int + ID of a pre-created Dataset DB record to populate. + params : dict + Parameters including ``inferred_types`` and ``column_renames``. + """ + + dataset_id: int + params: Dict[str, Any] = {} + + +@router.post( + "/{source_name}/{dataset_id:path}/import", status_code=status.HTTP_201_CREATED +) +async def import_dataset( + source_name: str, + dataset_id: str, + body: ImportRequest, + registry: "ComponentRegistry" = Depends(lambda: di["component_registry"]), + job_queue=Depends(lambda: di["job_queue"]), +) -> Dict[str, Any]: + """Enqueue a DatasetJob to import a dataset from an external source. + + Parameters + ---------- + source_name : str + Registered DatasetSource class name. + dataset_id : str + Source-specific dataset identifier (URL-encoded). + body : ImportRequest + Contains the DashAI dataset_id and params. + registry : ComponentRegistry + Injected component registry. + job_queue : BaseJobQueue + Injected job queue. + + Returns + ------- + dict + ``{"job_id": int, "dataset_id": int}`` — the enqueued job and dataset IDs. + """ + from DashAI.back.job.dataset_job import DatasetJob + + _get_source(source_name, registry) # validates source exists, raises 404 if not + + job = DatasetJob( + kwargs={ + "dataset_id": body.dataset_id, + "source_name": source_name, + "dataset_source_id": unquote(dataset_id), + "params": body.params, + } + ) + job.set_status_as_delivered() + result = job_queue.put(job) + # huey.api.Result has .id (task UUID string); plain int in other modes + job_id = getattr(result, "id", result) + + return {"job_id": job_id, "dataset_id": body.dataset_id} diff --git a/DashAI/back/api/api_v1/endpoints/datasets.py b/DashAI/back/api/api_v1/endpoints/datasets.py index 5ea732ac6..6d6f1edb5 100644 --- a/DashAI/back/api/api_v1/endpoints/datasets.py +++ b/DashAI/back/api/api_v1/endpoints/datasets.py @@ -1,10 +1,13 @@ +import asyncio import hashlib +import io import logging +import os import time import zipfile from collections import OrderedDict from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any, Dict, Optional import pyarrow as pa from fastapi import APIRouter, Depends, File, Form, Query, Response, UploadFile, status @@ -29,6 +32,90 @@ from DashAI.back.dependencies.registry import ComponentRegistry +def _build_image_preview_sample( + extract_dir: str, image_extensions: set, max_rows: int = 5 +) -> tuple: + """Walk an extracted imagefolder directory and return sample rows, + total count, and whether class subdirectories exist. + + Returns + ------- + tuple of (list[dict], int, bool) + (sample_rows, total_images, has_labels) + """ + import os + + samples = [] + total = 0 + labels_seen = set() + for root, dirs, files in os.walk(extract_dir): + dirs[:] = [ + d for d in dirs if not d.startswith("__MACOSX") and not d.startswith(".") + ] + for f in sorted(files): + if f.startswith("."): + continue + ext = os.path.splitext(f)[1].lower() + if ext not in image_extensions: + continue + total += 1 + label = os.path.basename(root) + labels_seen.add(label) + if len(samples) >= max_rows: + continue + filepath = os.path.join(root, f) + try: + with open(filepath, "rb") as fh: + img_bytes = fh.read() + if len(img_bytes) == 0: + continue + thumb = _image_bytes_to_thumbnail_data_uri(img_bytes) + if thumb == "[Image]": + continue + samples.append({"image": thumb, "_label": label}) + except Exception: + continue + + has_labels = len(labels_seen) > 1 + + if has_labels: + for s in samples: + s["label"] = s.pop("_label") + else: + for s in samples: + s.pop("_label", None) + + return samples, total, has_labels + + +def _image_bytes_to_thumbnail_data_uri(img_bytes: bytes, max_size: int = 64) -> str: + """Convert raw image bytes to a small base64 data URI thumbnail.""" + import base64 + import io + + from PIL import Image + + try: + img = Image.open(io.BytesIO(img_bytes)) + if img.mode in ("CMYK", "YCbCr", "LAB", "HSV"): + img = img.convert("RGB") + elif img.mode in ("LA", "PA"): + img = img.convert("RGBA") + img.thumbnail((max_size, max_size)) + buf = io.BytesIO() + img.save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode() + return f"data:image/png;base64,{b64}" + except Exception as e: + logger.warning( + "Failed to generate thumbnail (len=%d, head=%r): %s", + len(img_bytes), + img_bytes[:16], + e, + ) + return "[Image]" + + logger = logging.getLogger(__name__) router = APIRouter() @@ -62,6 +149,13 @@ def get(self, path: str, filter_model: str | None, sort_model: str | None): if time.time() - ts > self._ttl: del self._store[key] return None + arrow_file_path = f"{path}/dataset/data.arrow" + try: + if os.path.getmtime(arrow_file_path) > ts: + del self._store[key] + return None + except OSError: + pass self._store.move_to_end(key) return table, total @@ -106,7 +200,7 @@ def _load_and_filter_table( from DashAI.back.dataloaders.classes.dashai_dataset import get_dataset_info arrow_file_path = f"{path}/dataset/data.arrow" - with pa.memory_map(arrow_file_path, "r") as source: + with pa.OSFile(arrow_file_path, "rb") as source: reader = ipc.RecordBatchFileReader(source) batches = [reader.get_batch(i) for i in range(reader.num_record_batches)] table = pa.Table.from_batches(batches) @@ -277,10 +371,35 @@ async def filter_dataset_file( start = page * page_size paged_table = table.slice(start, page_size) - rows = [ - {col: paged_table[col][i].as_py() for col in paged_table.schema.names} - for i in range(paged_table.num_rows) - ] + + image_cols = { + col + for col in paged_table.schema.names + if pa.types.is_struct(paged_table.schema.field(col).type) + or pa.types.is_large_binary(paged_table.schema.field(col).type) + or pa.types.is_binary(paged_table.schema.field(col).type) + } + + rows = [] + for i in range(paged_table.num_rows): + row = {} + for col in paged_table.schema.names: + val = paged_table[col][i].as_py() + if col in image_cols: + if isinstance(val, dict): + img_bytes = val.get("bytes", b"") + elif isinstance(val, bytes): + img_bytes = val + else: + img_bytes = b"" + row[col] = ( + _image_bytes_to_thumbnail_data_uri(img_bytes) + if img_bytes + else "[Image]" + ) + else: + row[col] = val + rows.append(row) return JSONResponse(content={"rows": rows, "total": total}) @@ -1360,7 +1479,7 @@ async def get_dataset_file( end = start + page_size rows_collected = 0 - with pa.memory_map(arrow_file_path, "r") as source: + with pa.OSFile(arrow_file_path, "rb") as source: reader = ipc.RecordBatchFileReader(source) current_index = 0 @@ -1380,11 +1499,32 @@ async def get_dataset_file( slice_end = min(batch.num_rows, end - batch_start) sliced_batch = batch.slice(slice_start, slice_end - slice_start) + image_cols = { + col + for col in sliced_batch.schema.names + if pa.types.is_struct(sliced_batch.schema.field(col).type) + or pa.types.is_large_binary(sliced_batch.schema.field(col).type) + or pa.types.is_binary(sliced_batch.schema.field(col).type) + } + for j in range(sliced_batch.num_rows): - row = { - col: sliced_batch[col][j].as_py() - for col in sliced_batch.schema.names - } + row = {} + for col in sliced_batch.schema.names: + val = sliced_batch[col][j].as_py() + if col in image_cols: + if isinstance(val, dict): + img_bytes = val.get("bytes", b"") + elif isinstance(val, bytes): + img_bytes = val + else: + img_bytes = b"" + row[col] = ( + _image_bytes_to_thumbnail_data_uri(img_bytes) + if img_bytes + else "[Image]" + ) + else: + row[col] = val rows.append(row) rows_collected += 1 if rows_collected >= page_size: @@ -1398,72 +1538,180 @@ async def get_dataset_file( return JSONResponse(content={"rows": rows, "total": total_rows}) +def _build_image_zip(table: "pa.Table") -> io.BytesIO: + """Build a ZIP buffer from an image dataset table. + + Uses ZIP_STORED (no compression) because image formats (JPEG, PNG) are + already compressed — DEFLATE gains nothing but wastes significant CPU time. + """ + label_col = next( + ( + col + for col in table.column_names + if not pa.types.is_struct(table.schema.field(col).type) + and ( + pa.types.is_string(table.schema.field(col).type) + or pa.types.is_large_string(table.schema.field(col).type) + or pa.types.is_dictionary(table.schema.field(col).type) + ) + ), + None, + ) + + image_cols = [ + col + for col in table.column_names + if pa.types.is_struct(table.schema.field(col).type) + ] + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_STORED) as zf: + seen_paths: set = set() + for col in image_cols: + arr_col = table.column(col) + label_arr = table.column(label_col) if label_col else None + for i in range(table.num_rows): + val = arr_col[i].as_py() + if not (isinstance(val, dict) and val.get("bytes")): + continue + img_bytes = val["bytes"] + raw_path = val.get("path") or "" + raw_ext = os.path.splitext(raw_path)[1].lstrip(".").lower() + ext = raw_ext if raw_ext else "png" + orig_fname = os.path.basename(raw_path) or f"{col}_{i}.{ext}" + + if label_arr is not None: + label_val = str(label_arr[i].as_py()) + entry = f"{label_val}/{orig_fname}" + else: + entry = f"images/{orig_fname}" + + if entry in seen_paths: + stem, dot_ext = os.path.splitext(orig_fname) + entry_base = ( + f"{label_val}/{stem}" + if label_arr is not None + else f"images/{stem}" + ) + entry = f"{entry_base}_{i}{dot_ext}" + seen_paths.add(entry) + zf.writestr(entry, img_bytes) + + zip_buffer.seek(0) + return zip_buffer + + +def _iter_buffer(buf: io.BytesIO, chunk_size: int = 65536): + """Yield a BytesIO in fixed-size chunks for StreamingResponse.""" + while True: + chunk = buf.read(chunk_size) + if not chunk: + break + yield chunk + + +async def _build_export_response( + table: "pa.Table", dataset_name: str +) -> StreamingResponse: + """Build a StreamingResponse for dataset export. + + For image datasets (struct columns), produces a ZIP in ImageFolder format: + ``