diff --git a/src/dispatch/entity_type/service.py b/src/dispatch/entity_type/service.py index c4dbe8775fec..60f7c3a4bb64 100644 --- a/src/dispatch/entity_type/service.py +++ b/src/dispatch/entity_type/service.py @@ -56,8 +56,10 @@ def get_all(*, db_session: Session, scope: str = None) -> Query: return db_session.query(EntityType) -def create(*, db_session: Session, entity_type_in: EntityTypeCreate) -> EntityType: - """Creates a new entity type.""" +def create( + *, db_session: Session, entity_type_in: EntityTypeCreate, case_id: int | None = None +) -> EntityType: + """Creates a new entity type and extracts entities from existing signal instances.""" project = project_service.get_by_name_or_raise( db_session=db_session, project_in=entity_type_in.project ) @@ -75,10 +77,35 @@ def create(*, db_session: Session, entity_type_in: EntityTypeCreate) -> EntityTy db_session.add(entity_type) db_session.commit() + + # Extract entities for all relevant signal instances + from dispatch.signal.models import SignalInstance + from dispatch.entity.service import find_entities + + if case_id: + # Get all signal instances for the case + signal_instances = ( + db_session.query(SignalInstance) + .filter(SignalInstance.case_id == case_id) + .limit(100) + .all() + ) + # Extract and create entities for these instances using only the new entity_type + for signal_instance in signal_instances: + new_entities = find_entities(db_session, signal_instance, [entity_type]) + # Associate new entities with the signal_instance + for entity in new_entities: + if entity not in signal_instance.entities: + signal_instance.entities.append(entity) + + db_session.commit() + return entity_type -def get_or_create(*, db_session: Session, entity_type_in: EntityTypeCreate) -> EntityType: +def get_or_create( + *, db_session: Session, entity_type_in: EntityTypeCreate, case_id: int | None = None +) -> EntityType: """Gets or creates a new entity type.""" q = ( db_session.query(EntityType) @@ -90,7 +117,7 @@ def get_or_create(*, db_session: Session, entity_type_in: EntityTypeCreate) -> E if instance: return instance - return create(db_session=db_session, entity_type_in=entity_type_in) + return create(db_session=db_session, entity_type_in=entity_type_in, case_id=case_id) def update( diff --git a/src/dispatch/entity_type/views.py b/src/dispatch/entity_type/views.py index acd090169231..ddf274949896 100644 --- a/src/dispatch/entity_type/views.py +++ b/src/dispatch/entity_type/views.py @@ -1,14 +1,14 @@ -from typing import Union +from typing import List from fastapi import APIRouter, HTTPException, status from pydantic.error_wrappers import ErrorWrapper, ValidationError from sqlalchemy.exc import IntegrityError +from dispatch.case.service import get as get_case from dispatch.database.core import DbSession from dispatch.exceptions import ExistsError from dispatch.database.service import CommonParameters, search_filter_sort_paginate from dispatch.models import PrimaryKey -from dispatch.signal.service import get_signal_instance from dispatch.signal.models import SignalInstanceRead from .models import ( @@ -54,11 +54,24 @@ def create_entity_type(db_session: DbSession, entity_type_in: EntityTypeCreate): return entity_type -@router.put("/recalculate/{entity_type_id}/{signal_instance_id}", response_model=SignalInstanceRead) -def recalculate( - db_session: DbSession, entity_type_id: PrimaryKey, signal_instance_id: Union[str, PrimaryKey] +@router.post("/{case_id}", response_model=EntityTypeRead) +def create_entity_type_with_case( + db_session: DbSession, case_id: PrimaryKey, entity_type_in: EntityTypeCreate ): - """Recalculates the associated entities for a signal instance.""" + """Create a new entity.""" + try: + entity_type = create(db_session=db_session, entity_type_in=entity_type_in, case_id=case_id) + except IntegrityError: + raise ValidationError( + [ErrorWrapper(ExistsError(msg="An entity with this name already exists."), loc="name")], + model=EntityTypeCreate, + ) from None + return entity_type + + +@router.put("/recalculate/{entity_type_id}/{case_id}", response_model=List[SignalInstanceRead]) +def recalculate(db_session: DbSession, entity_type_id: PrimaryKey, case_id: PrimaryKey): + """Recalculates the associated entities for all signal instances in a case.""" entity_type = get( db_session=db_session, entity_type_id=entity_type_id, @@ -69,21 +82,32 @@ def recalculate( detail=[{"msg": "An entity type with this id does not exist."}], ) - signal_instance = get_signal_instance( - db_session=db_session, - signal_instance_id=signal_instance_id, - ) - if not signal_instance: + case = get_case(db_session=db_session, case_id=case_id) + if not case: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=[{"msg": "A signal instance with this id does not exist."}], + detail=[{"msg": "A case with this id does not exist."}], ) - return recalculate_entity_flow( - db_session=db_session, - entity_type=entity_type, - signal_instance=signal_instance, - ) + # Get all signal instances associated with the case + signal_instances = case.signal_instances + if not signal_instances: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=[{"msg": "No signal instances found for this case."}], + ) + + # Recalculate entities for each signal instance + updated_signal_instances = [] + for signal_instance in signal_instances: + updated_signal_instance = recalculate_entity_flow( + db_session=db_session, + entity_type=entity_type, + signal_instance=signal_instance, + ) + updated_signal_instances.append(updated_signal_instance) + + return updated_signal_instances @router.put("/{entity_type_id}", response_model=EntityTypeRead) diff --git a/src/dispatch/static/dispatch/src/entity_type/EntityTypeCreateDialogV2.vue b/src/dispatch/static/dispatch/src/entity_type/EntityTypeCreateDialogV2.vue index 78eda3e07136..917606e8b092 100644 --- a/src/dispatch/static/dispatch/src/entity_type/EntityTypeCreateDialogV2.vue +++ b/src/dispatch/static/dispatch/src/entity_type/EntityTypeCreateDialogV2.vue @@ -60,7 +60,7 @@ :options="editorOptions" :editorMounted="editorMounted" language="json" - style="width: 100%; height: 100%" + style="width: 100%; height: 240px" /> @@ -330,7 +330,12 @@ const saveEntityType = async () => { signals: [signalGetResponse.data], } try { - const newEntityType = await EntityTypeApi.create(entityTypeData) + const newEntityType = await EntityTypeApi.create_with_case( + entityTypeData, + selectedCase.value.id + ) + // Use the case_id instead of signal_instance_id for recalculation + await EntityTypeApi.recalculate(newEntityType.data.id, selectedCase.value.id) emit("new-entity-type", newEntityType.data) store.commit( "notification_backend/addBeNotification", @@ -340,7 +345,6 @@ const saveEntityType = async () => { }, { root: true } ) - await EntityTypeApi.recalculate(newEntityType.data.id, props.signalObj.raw.id) } catch (error) { store.commit( "notification_backend/addBeNotification", diff --git a/src/dispatch/static/dispatch/src/entity_type/api.js b/src/dispatch/static/dispatch/src/entity_type/api.js index 260d94ddf597..c58b93bb8ea4 100644 --- a/src/dispatch/static/dispatch/src/entity_type/api.js +++ b/src/dispatch/static/dispatch/src/entity_type/api.js @@ -15,6 +15,10 @@ export default { return API.post(`${resource}`, payload) }, + create_with_case(payload, caseId) { + return API.post(`${resource}/${caseId}`, payload) + }, + update(entityTypeId, payload) { return API.put(`${resource}/${entityTypeId}`, payload) }, @@ -27,7 +31,7 @@ export default { return API.delete(`${resource}/${entityTypeId}`) }, - recalculate(entityTypeId, signalInstanceId) { - return API.put(`${resource}/recalculate/${entityTypeId}/${signalInstanceId}`) + recalculate(entityTypeId, caseId) { + return API.put(`${resource}/recalculate/${entityTypeId}/${caseId}`) }, } diff --git a/src/dispatch/static/dispatch/src/signal/NewRawSignalViewer.vue b/src/dispatch/static/dispatch/src/signal/NewRawSignalViewer.vue index f7bffe68670e..c329e6364e3e 100644 --- a/src/dispatch/static/dispatch/src/signal/NewRawSignalViewer.vue +++ b/src/dispatch/static/dispatch/src/signal/NewRawSignalViewer.vue @@ -1,5 +1,5 @@