Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DashAI/back/dataloaders/classes/dashai_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def get_split(self, split_name: str) -> "DashAIDataset":

new_splits = {"split_indices": {split_name: indices}}
arrow_table = subset.arrow_table # with_format("arrow")[:] ####Check
subset = DashAIDataset(arrow_table, splits=new_splits)
subset = DashAIDataset(arrow_table, splits=new_splits, types=self._types)
return subset

@beartype
Expand Down
5 changes: 1 addition & 4 deletions DashAI/back/models/scikit_learn/sklearn_like_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ def predict(self, x_pred: "DashAIDataset") -> "ndarray":
from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset

if isinstance(x_pred, DashAIDataset):
try:
x_prepared = self.prepare_dataset(x_pred, is_fit=False)
except ValueError:
x_prepared = x_pred
x_prepared = self.prepare_dataset(x_pred, is_fit=False)
x_pred = x_prepared.to_pandas()
elif isinstance(x_pred, pd.DataFrame):
pass
Expand Down
69 changes: 42 additions & 27 deletions DashAI/back/services/scoring_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ def _get_all_profiles(

return all_profiles

def get_available_profiles(self, task_name: Optional[str]) -> List[Dict[str, Any]]:
@inject
def get_available_profiles(
self,
task_name: Optional[str],
component_registry: "ComponentRegistry" = lambda di: di["component_registry"],
) -> List[Dict[str, Any]]:
"""Get profiles available for a given task.

Fetches profiles from the task's SCORING_PROFILES class attribute
Expand All @@ -101,39 +106,49 @@ def get_available_profiles(self, task_name: Optional[str]) -> List[Dict[str, Any
task_name : Optional[str]
Task class name (e.g., "TabularClassificationTask") or None.
If None, returns all profiles from all tasks.
component_registry : ComponentRegistry
Registry to look up task class (injected).

Returns
-------
List[Dict[str, Any]]
List of profile dicts with keys: id, description, weights.
"""
all_profiles = self._get_all_profiles()
if task_name:
# Access the task class directly to avoid profile-ID collisions
# that occur when multiple tasks share the same profile ID keys
# (e.g. TabularClassificationTask and ImageClassificationTask both
# define "balanced").
try:
full_task_dict = component_registry[task_name]
task_class = full_task_dict.get("class")
if task_class and hasattr(task_class, "SCORING_PROFILES"):
return [
{
"id": profile_id,
"description": profile_data["description"],
"weights": profile_data["weights"],
}
for profile_id, profile_data in (
task_class.SCORING_PROFILES.items()
)
]
except KeyError:
log.debug(f"Task {task_name} not found in registry")
except Exception as e:
log.warning(f"Error accessing SCORING_PROFILES for {task_name}: {e}")
return []

if not task_name:
# No task specified; return all profiles
return [
{
"id": profile_id,
"description": profile_data["description"],
"weights": profile_data["weights"],
}
for profile_id, profile_data in all_profiles.items()
]

# Filter profiles by task
result = []
for profile_id, profile_data in all_profiles.items():
if profile_data.get("task_name") == task_name:
result.append(
{
"id": profile_id,
"description": profile_data["description"],
"weights": profile_data["weights"],
}
)

# If no profiles found for this task, return empty (graceful fallback)
return result
# No task specified — return one entry per unique profile ID across all tasks.
all_profiles = self._get_all_profiles()
return [
{
"id": profile_id,
"description": profile_data["description"],
"weights": profile_data["weights"],
}
for profile_id, profile_data in all_profiles.items()
]

def normalize_metric_value(
self,
Expand Down
18 changes: 13 additions & 5 deletions DashAI/front/src/components/models/ModelComparisonTable.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,20 @@ function ModelComparisonTable({
fetchProfiles();
}, [session?.task_name]);

// Stable string that changes only when a run's status changes.
// Used as a dep so the score fetch re-triggers after training completes
// without firing on every unrelated re-render of the parent.
const runStatusSignature = useMemo(
() => initialRuns.map((r) => `${r.id}:${r.status}`).join(","),
[initialRuns],
);

// ────────────────────────────────────────────────────────────────────────
// Fetch scores when profile or runs change
// Fetch scores when profile, split, session or any run status changes
// ────────────────────────────────────────────────────────────────────────

useEffect(() => {
if (!runs.length || !selectedProfile || !session?.id) return;
if (!initialRuns.length || !selectedProfile || !session?.id) return;

const fetchScores = async () => {
setLoadingScores(true);
Expand Down Expand Up @@ -152,7 +160,7 @@ function ModelComparisonTable({
};

fetchScores();
}, [selectedProfile, metricSplit, session?.id]);
}, [selectedProfile, metricSplit, session?.id, runStatusSignature]);

// ────────────────────────────────────────────────────────────────────────
// Build columns
Expand Down Expand Up @@ -297,8 +305,8 @@ function ModelComparisonTable({
</Tooltip>
),
Cell: ({ row }) => {
const { statusCode, id } = row.original;
const isRunning = statusCode === 1 || statusCode === 2;
const { status, id } = row.original;
const isRunning = status === 1 || status === 2;
if (isRunning) return "-";

const scoreData = scores[id];
Expand Down
Loading