diff --git a/src/dispatch/entity/service.py b/src/dispatch/entity/service.py index 3628811c5797..17aa032667b2 100644 --- a/src/dispatch/entity/service.py +++ b/src/dispatch/entity/service.py @@ -84,17 +84,32 @@ def get_all_by_signal(*, db_session: Session, signal_id: int) -> list[Entity]: ) -def get_all_desc_by_signal(*, db_session: Session, signal_id: int) -> list[Entity]: - """Gets all entities for a specific signal in descending order.""" - return ( +def get_all_desc_by_signal( + *, db_session: Session, signal_id: int, case_id: int = None +) -> list[Entity]: + """Gets all entities for a specific signal in descending order. + + Args: + db_session: The database session. + signal_id: The ID of the signal to filter by. + case_id: Optional case ID to further filter the entities. + + Returns: + A list of entities ordered by creation date in descending order. + """ + query = ( db_session.query(Entity) .join(Entity.signal_instances) .join(SignalInstance.signal) .filter(Signal.id == signal_id) - .order_by(desc(Entity.created_at)) - .all() ) + if case_id is not None: + # Add case filter if case_id is provided + query = query.filter(SignalInstance.case_id == case_id) + + return query.order_by(desc(Entity.created_at)).all() + def create(*, db_session: Session, entity_in: EntityCreate) -> Entity: """Creates a new entity.""" diff --git a/src/dispatch/plugins/dispatch_slack/case/interactive.py b/src/dispatch/plugins/dispatch_slack/case/interactive.py index 403b5e8a9533..c4b16ef1ec23 100644 --- a/src/dispatch/plugins/dispatch_slack/case/interactive.py +++ b/src/dispatch/plugins/dispatch_slack/case/interactive.py @@ -722,11 +722,17 @@ def snooze_button_click( subject = context["subject"] + case_id = None if subject.type == SignalSubjects.signal_instance: instance = signal_service.get_signal_instance( db_session=db_session, signal_instance_id=subject.id ) subject.id = instance.signal.id + elif subject.type == CaseSubjects.case: + case = case_service.get(db_session=db_session, case_id=subject.id) + case_id = case.id + subject.type = SignalSubjects.signal_instance + subject.id = case.signal_instances[0].signal.id signal = signal_service.get(db_session=db_session, signal_id=subject.id) blocks = [ @@ -743,6 +749,7 @@ def snooze_button_click( db_session=db_session, signal_id=signal.id, optional=True, + case_id=case_id, ) if entity_select_block: diff --git a/src/dispatch/plugins/dispatch_slack/case/messages.py b/src/dispatch/plugins/dispatch_slack/case/messages.py index 262de7b13c81..4c5dc0aa3b56 100644 --- a/src/dispatch/plugins/dispatch_slack/case/messages.py +++ b/src/dispatch/plugins/dispatch_slack/case/messages.py @@ -263,7 +263,8 @@ def create_action_buttons_message( project_id=project_id, channel_id=channel_id, ).json() - mfa_button_metadata = SubjectMetadata( + + case_button_metadata = SubjectMetadata( type=CaseSubjects.case, organization_slug=organization_slug, id=case.id, @@ -288,12 +289,12 @@ def create_action_buttons_message( Button( text="💤 Snooze Alert", action_id=SignalNotificationActions.snooze, - value=button_metadata, + value=case_button_metadata, ), Button( text="👤 User MFA Challenge", action_id=CaseNotificationActions.user_mfa, - value=mfa_button_metadata, + value=case_button_metadata, ), ] ) diff --git a/src/dispatch/plugins/dispatch_slack/fields.py b/src/dispatch/plugins/dispatch_slack/fields.py index ad0efa955c29..bf967f5251b1 100644 --- a/src/dispatch/plugins/dispatch_slack/fields.py +++ b/src/dispatch/plugins/dispatch_slack/fields.py @@ -668,13 +668,14 @@ def entity_select( action_id: str = DefaultActionIds.entity_select, block_id: str = DefaultBlockIds.entity_select, label="Entities", + case_id: int = None, **kwargs, ): """Creates an entity select.""" entity_options = [ {"text": entity.value[:75], "value": entity.id} for entity in entity_service.get_all_desc_by_signal( - db_session=db_session, signal_id=signal_id + db_session=db_session, signal_id=signal_id, case_id=case_id ) if entity.value ] diff --git a/tests/entity/test_entity_service.py b/tests/entity/test_entity_service.py index 26ac5d1b2324..2f44989d951f 100644 --- a/tests/entity/test_entity_service.py +++ b/tests/entity/test_entity_service.py @@ -34,6 +34,53 @@ def test_get_all_by_signal(session, entity, signal_instance): assert entities[0].id == entity.id +def test_get_all_desc_by_signal_without_case(session, entity, signal_instance): + """ + Test get_all_desc_by_signal returns entities for a signal in descending order when case_id is not provided. + """ + from dispatch.entity.service import get_all_desc_by_signal + + # Associate the entity with the signal_instance + signal_instance.entities.append(entity) + session.add(signal_instance) + session.commit() + + signal_id = signal_instance.signal_id + + # Should return the entity, ordered by created_at desc + entities = get_all_desc_by_signal(db_session=session, signal_id=signal_id) + assert entity in entities + assert len(entities) == 1 + assert entities[0].id == entity.id + + +def test_get_all_desc_by_signal_with_case(session, entity, signal_instance): + """ + Test get_all_desc_by_signal returns entities for a signal and case in descending order when case_id is provided. + """ + from dispatch.entity.service import get_all_desc_by_signal + + # Attach the entity to the signal_instance and commit + signal_instance.entities.append(entity) + session.add(signal_instance) + session.commit() + + # The case is linked via signal_instance.case + case = signal_instance.case + signal_id = signal_instance.signal_id + case_id = case.id + + # Should return the entity, filtered by both signal and case + entities = get_all_desc_by_signal(db_session=session, signal_id=signal_id, case_id=case_id) + assert entity in entities + assert len(entities) == 1 + assert entities[0].id == entity.id + + # If we pass a non-matching case_id, should return empty + entities_none = get_all_desc_by_signal(db_session=session, signal_id=signal_id, case_id=case_id + 999) + assert len(entities_none) == 0 + + def test_create(session, entity_type, project): from dispatch.entity.models import EntityCreate from dispatch.entity.service import create