diff --git a/DashAI/back/dataloaders/classes/dashai_dataset.py b/DashAI/back/dataloaders/classes/dashai_dataset.py index 82d96b2c9..83a12eafd 100644 --- a/DashAI/back/dataloaders/classes/dashai_dataset.py +++ b/DashAI/back/dataloaders/classes/dashai_dataset.py @@ -637,7 +637,7 @@ def get_split(self, split_name: str) -> "DashAIDataset": new_splits = {"split_indices": {split_name: indices}} arrow_table = subset.arrow_table # with_format("arrow")[:] ####Check - subset = DashAIDataset(arrow_table, splits=new_splits) + subset = DashAIDataset(arrow_table, splits=new_splits, types=self._types) return subset @beartype diff --git a/DashAI/back/models/scikit_learn/sklearn_like_classifier.py b/DashAI/back/models/scikit_learn/sklearn_like_classifier.py index a50af4a5b..919f81a57 100644 --- a/DashAI/back/models/scikit_learn/sklearn_like_classifier.py +++ b/DashAI/back/models/scikit_learn/sklearn_like_classifier.py @@ -36,10 +36,7 @@ def predict(self, x_pred: "DashAIDataset") -> "ndarray": from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset if isinstance(x_pred, DashAIDataset): - try: - x_prepared = self.prepare_dataset(x_pred, is_fit=False) - except ValueError: - x_prepared = x_pred + x_prepared = self.prepare_dataset(x_pred, is_fit=False) x_pred = x_prepared.to_pandas() elif isinstance(x_pred, pd.DataFrame): pass diff --git a/DashAI/back/services/scoring_service.py b/DashAI/back/services/scoring_service.py index afb6ed9a2..b787b5917 100644 --- a/DashAI/back/services/scoring_service.py +++ b/DashAI/back/services/scoring_service.py @@ -90,7 +90,12 @@ def _get_all_profiles( return all_profiles - def get_available_profiles(self, task_name: Optional[str]) -> List[Dict[str, Any]]: + @inject + def get_available_profiles( + self, + task_name: Optional[str], + component_registry: "ComponentRegistry" = lambda di: di["component_registry"], + ) -> List[Dict[str, Any]]: """Get profiles available for a given task. Fetches profiles from the task's SCORING_PROFILES class attribute @@ -101,39 +106,49 @@ def get_available_profiles(self, task_name: Optional[str]) -> List[Dict[str, Any task_name : Optional[str] Task class name (e.g., "TabularClassificationTask") or None. If None, returns all profiles from all tasks. + component_registry : ComponentRegistry + Registry to look up task class (injected). Returns ------- List[Dict[str, Any]] List of profile dicts with keys: id, description, weights. """ - all_profiles = self._get_all_profiles() + if task_name: + # Access the task class directly to avoid profile-ID collisions + # that occur when multiple tasks share the same profile ID keys + # (e.g. TabularClassificationTask and ImageClassificationTask both + # define "balanced"). + try: + full_task_dict = component_registry[task_name] + task_class = full_task_dict.get("class") + if task_class and hasattr(task_class, "SCORING_PROFILES"): + return [ + { + "id": profile_id, + "description": profile_data["description"], + "weights": profile_data["weights"], + } + for profile_id, profile_data in ( + task_class.SCORING_PROFILES.items() + ) + ] + except KeyError: + log.debug(f"Task {task_name} not found in registry") + except Exception as e: + log.warning(f"Error accessing SCORING_PROFILES for {task_name}: {e}") + return [] - if not task_name: - # No task specified; return all profiles - return [ - { - "id": profile_id, - "description": profile_data["description"], - "weights": profile_data["weights"], - } - for profile_id, profile_data in all_profiles.items() - ] - - # Filter profiles by task - result = [] - for profile_id, profile_data in all_profiles.items(): - if profile_data.get("task_name") == task_name: - result.append( - { - "id": profile_id, - "description": profile_data["description"], - "weights": profile_data["weights"], - } - ) - - # If no profiles found for this task, return empty (graceful fallback) - return result + # No task specified — return one entry per unique profile ID across all tasks. + all_profiles = self._get_all_profiles() + return [ + { + "id": profile_id, + "description": profile_data["description"], + "weights": profile_data["weights"], + } + for profile_id, profile_data in all_profiles.items() + ] def normalize_metric_value( self, diff --git a/DashAI/front/src/components/models/ModelComparisonTable.jsx b/DashAI/front/src/components/models/ModelComparisonTable.jsx index 78a7d1397..571df06b4 100644 --- a/DashAI/front/src/components/models/ModelComparisonTable.jsx +++ b/DashAI/front/src/components/models/ModelComparisonTable.jsx @@ -114,12 +114,20 @@ function ModelComparisonTable({ fetchProfiles(); }, [session?.task_name]); + // Stable string that changes only when a run's status changes. + // Used as a dep so the score fetch re-triggers after training completes + // without firing on every unrelated re-render of the parent. + const runStatusSignature = useMemo( + () => initialRuns.map((r) => `${r.id}:${r.status}`).join(","), + [initialRuns], + ); + // ──────────────────────────────────────────────────────────────────────── - // Fetch scores when profile or runs change + // Fetch scores when profile, split, session or any run status changes // ──────────────────────────────────────────────────────────────────────── useEffect(() => { - if (!runs.length || !selectedProfile || !session?.id) return; + if (!initialRuns.length || !selectedProfile || !session?.id) return; const fetchScores = async () => { setLoadingScores(true); @@ -152,7 +160,7 @@ function ModelComparisonTable({ }; fetchScores(); - }, [selectedProfile, metricSplit, session?.id]); + }, [selectedProfile, metricSplit, session?.id, runStatusSignature]); // ──────────────────────────────────────────────────────────────────────── // Build columns @@ -297,8 +305,8 @@ function ModelComparisonTable({ ), Cell: ({ row }) => { - const { statusCode, id } = row.original; - const isRunning = statusCode === 1 || statusCode === 2; + const { status, id } = row.original; + const isRunning = status === 1 || status === 2; if (isRunning) return "-"; const scoreData = scores[id];