Skip to content
This repository was archived by the owner on Sep 3, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions src/dispatch/entity/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,32 @@ def get_all_by_signal(*, db_session: Session, signal_id: int) -> list[Entity]:
)


def get_all_desc_by_signal(*, db_session: Session, signal_id: int) -> list[Entity]:
"""Gets all entities for a specific signal in descending order."""
return (
def get_all_desc_by_signal(
*, db_session: Session, signal_id: int, case_id: int = None
) -> list[Entity]:
"""Gets all entities for a specific signal in descending order.

Args:
db_session: The database session.
signal_id: The ID of the signal to filter by.
case_id: Optional case ID to further filter the entities.

Returns:
A list of entities ordered by creation date in descending order.
"""
query = (
db_session.query(Entity)
.join(Entity.signal_instances)
.join(SignalInstance.signal)
.filter(Signal.id == signal_id)
.order_by(desc(Entity.created_at))
.all()
)

if case_id is not None:
# Add case filter if case_id is provided
query = query.filter(SignalInstance.case_id == case_id)

return query.order_by(desc(Entity.created_at)).all()


def create(*, db_session: Session, entity_in: EntityCreate) -> Entity:
"""Creates a new entity."""
Expand Down
7 changes: 7 additions & 0 deletions src/dispatch/plugins/dispatch_slack/case/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,11 +722,17 @@ def snooze_button_click(

subject = context["subject"]

case_id = None
if subject.type == SignalSubjects.signal_instance:
instance = signal_service.get_signal_instance(
db_session=db_session, signal_instance_id=subject.id
)
subject.id = instance.signal.id
elif subject.type == CaseSubjects.case:
case = case_service.get(db_session=db_session, case_id=subject.id)
case_id = case.id
subject.type = SignalSubjects.signal_instance
subject.id = case.signal_instances[0].signal.id

signal = signal_service.get(db_session=db_session, signal_id=subject.id)
blocks = [
Expand All @@ -743,6 +749,7 @@ def snooze_button_click(
db_session=db_session,
signal_id=signal.id,
optional=True,
case_id=case_id,
)

if entity_select_block:
Expand Down
7 changes: 4 additions & 3 deletions src/dispatch/plugins/dispatch_slack/case/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ def create_action_buttons_message(
project_id=project_id,
channel_id=channel_id,
).json()
mfa_button_metadata = SubjectMetadata(

case_button_metadata = SubjectMetadata(
type=CaseSubjects.case,
organization_slug=organization_slug,
id=case.id,
Expand All @@ -288,12 +289,12 @@ def create_action_buttons_message(
Button(
text="💤 Snooze Alert",
action_id=SignalNotificationActions.snooze,
value=button_metadata,
value=case_button_metadata,
),
Button(
text="👤 User MFA Challenge",
action_id=CaseNotificationActions.user_mfa,
value=mfa_button_metadata,
value=case_button_metadata,
),
]
)
Expand Down
3 changes: 2 additions & 1 deletion src/dispatch/plugins/dispatch_slack/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,13 +668,14 @@ def entity_select(
action_id: str = DefaultActionIds.entity_select,
block_id: str = DefaultBlockIds.entity_select,
label="Entities",
case_id: int = None,
**kwargs,
):
"""Creates an entity select."""
entity_options = [
{"text": entity.value[:75], "value": entity.id}
for entity in entity_service.get_all_desc_by_signal(
db_session=db_session, signal_id=signal_id
db_session=db_session, signal_id=signal_id, case_id=case_id
)
if entity.value
]
Expand Down
47 changes: 47 additions & 0 deletions tests/entity/test_entity_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,53 @@ def test_get_all_by_signal(session, entity, signal_instance):
assert entities[0].id == entity.id


def test_get_all_desc_by_signal_without_case(session, entity, signal_instance):
"""
Test get_all_desc_by_signal returns entities for a signal in descending order when case_id is not provided.
"""
from dispatch.entity.service import get_all_desc_by_signal

# Associate the entity with the signal_instance
signal_instance.entities.append(entity)
session.add(signal_instance)
session.commit()

signal_id = signal_instance.signal_id

# Should return the entity, ordered by created_at desc
entities = get_all_desc_by_signal(db_session=session, signal_id=signal_id)
assert entity in entities
assert len(entities) == 1
assert entities[0].id == entity.id


def test_get_all_desc_by_signal_with_case(session, entity, signal_instance):
"""
Test get_all_desc_by_signal returns entities for a signal and case in descending order when case_id is provided.
"""
from dispatch.entity.service import get_all_desc_by_signal

# Attach the entity to the signal_instance and commit
signal_instance.entities.append(entity)
session.add(signal_instance)
session.commit()

# The case is linked via signal_instance.case
case = signal_instance.case
signal_id = signal_instance.signal_id
case_id = case.id

# Should return the entity, filtered by both signal and case
entities = get_all_desc_by_signal(db_session=session, signal_id=signal_id, case_id=case_id)
assert entity in entities
assert len(entities) == 1
assert entities[0].id == entity.id

# If we pass a non-matching case_id, should return empty
entities_none = get_all_desc_by_signal(db_session=session, signal_id=signal_id, case_id=case_id + 999)
assert len(entities_none) == 0


def test_create(session, entity_type, project):
from dispatch.entity.models import EntityCreate
from dispatch.entity.service import create
Expand Down
Loading