Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions dojo/risk_acceptance/api.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from abc import ABC, abstractmethod
from typing import NamedTuple

from django.core.exceptions import PermissionDenied
from django.db.models import QuerySet
from django.utils import timezone
from drf_spectacular.utils import extend_schema
from rest_framework import serializers, status
from rest_framework.decorators import action
from rest_framework.permissions import IsAdminUser
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response

from dojo.api_v2.permissions import UserHasRiskAcceptanceRelatedObjectPermission
from dojo.api_v2.serializers import RiskAcceptanceSerializer
from dojo.authorization.roles_permissions import Permissions
from dojo.engagement.queries import get_authorized_engagements
from dojo.models import Risk_Acceptance, User, Vulnerability_Id
from dojo.models import Engagement, Risk_Acceptance, User, Vulnerability_Id

AcceptedRisk = NamedTuple("AcceptedRisk", (("vulnerability_id", str), ("justification", str), ("accepted_by", str)))

Expand Down Expand Up @@ -40,10 +42,13 @@ def risk_application_model_class(self):
request=AcceptedRiskSerializer(many=True),
responses={status.HTTP_201_CREATED: RiskAcceptanceSerializer(many=True)},
)
@action(methods=["post"], detail=True, permission_classes=[IsAdminUser], serializer_class=AcceptedRiskSerializer,
filter_backends=[], pagination_class=None)
@action(methods=["post"], detail=True, permission_classes=[IsAuthenticated, UserHasRiskAcceptanceRelatedObjectPermission],
serializer_class=AcceptedRiskSerializer, filter_backends=[], pagination_class=None)
def accept_risks(self, request, pk=None):
model = self.get_object()
product = model.product if isinstance(model, Engagement) else model.engagement.product
if not product.enable_full_risk_acceptance:
raise PermissionDenied
serializer = AcceptedRiskSerializer(data=request.data, many=True)
if serializer.is_valid():
accepted_risks = serializer.save()
Expand All @@ -63,7 +68,7 @@ class AcceptedFindingsMixin(ABC):
request=AcceptedRiskSerializer(many=True),
responses={status.HTTP_201_CREATED: RiskAcceptanceSerializer(many=True)},
)
@action(methods=["post"], detail=False, permission_classes=[IsAdminUser], serializer_class=AcceptedRiskSerializer)
@action(methods=["post"], detail=False, permission_classes=[IsAuthenticated], serializer_class=AcceptedRiskSerializer)
def accept_risks(self, request):
serializer = AcceptedRiskSerializer(data=request.data, many=True)
if serializer.is_valid():
Expand All @@ -72,7 +77,9 @@ def accept_risks(self, request):
return Response(data=serializer.errors, status=status.HTTP_400_BAD_REQUEST)
owner = request.user
accepted_result = []
for engagement in get_authorized_engagements(Permissions.Engagement_View):
for engagement in get_authorized_engagements(Permissions.Risk_Acceptance):
if not engagement.product.enable_full_risk_acceptance:
continue
base_findings = engagement.unaccepted_open_findings
accepted = _accept_risks(accepted_risks, base_findings, owner)
engagement.accept_risks(accepted)
Expand Down
188 changes: 188 additions & 0 deletions unittests/test_bulk_risk_acceptance_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Engagement,
Finding,
Product,
Product_Member,
Product_Type,
Product_Type_Member,
Role,
Expand Down Expand Up @@ -117,3 +118,190 @@ def test_finding_accept_risks(self):
for ra in self.engagement_2b.risk_acceptance.all():
for finding in ra.accepted_findings.all():
self.assertEqual(self.engagement_2a.product, finding.test.engagement.product)


class TestBulkRiskAcceptanceRbac(APITestCase):

"""Tests that accept_risks endpoints use RBAC (Permissions.Risk_Acceptance) instead of is_staff."""

@classmethod
def setUpTestData(cls):
cls.product_type = Product_Type.objects.create(name="RBAC Test Type")
cls.test_type = Test_Type.objects.create(name="RBAC Mock Scan", static_tool=True)

# Product with full risk acceptance enabled (default)
cls.product_enabled = Product.objects.create(
prod_type=cls.product_type, name="RBAC Enabled",
description="Full risk acceptance enabled",
enable_full_risk_acceptance=True,
)
# Product with full risk acceptance disabled
cls.product_disabled = Product.objects.create(
prod_type=cls.product_type, name="RBAC Disabled",
description="Full risk acceptance disabled",
enable_full_risk_acceptance=False,
)

cls.engagement_enabled = Engagement.objects.create(
product=cls.product_enabled,
target_start=datetime.datetime(2000, 1, 1, tzinfo=datetime.UTC),
target_end=datetime.datetime(2000, 2, 1, tzinfo=datetime.UTC),
)
cls.engagement_disabled = Engagement.objects.create(
product=cls.product_disabled,
target_start=datetime.datetime(2000, 1, 1, tzinfo=datetime.UTC),
target_end=datetime.datetime(2000, 2, 1, tzinfo=datetime.UTC),
)

cls.test_enabled = Test.objects.create(
engagement=cls.engagement_enabled, test_type=cls.test_type,
target_start=datetime.datetime(2000, 1, 1, tzinfo=datetime.UTC),
target_end=datetime.datetime(2000, 2, 1, tzinfo=datetime.UTC),
)
cls.test_disabled = Test.objects.create(
engagement=cls.engagement_disabled, test_type=cls.test_type,
target_start=datetime.datetime(2000, 1, 1, tzinfo=datetime.UTC),
target_end=datetime.datetime(2000, 2, 1, tzinfo=datetime.UTC),
)

# Writer user: has Risk_Acceptance permission, NOT is_staff
cls.writer = User.objects.create(username="rbac_writer", is_staff=False)
cls.writer_token = Token.objects.create(user=cls.writer)
Product_Member.objects.create(
product=cls.product_enabled, user=cls.writer,
role=Role.objects.get(id=Roles.Writer),
)
Product_Member.objects.create(
product=cls.product_disabled, user=cls.writer,
role=Role.objects.get(id=Roles.Writer),
)

# Reader user: does NOT have Risk_Acceptance permission, NOT is_staff
cls.reader = User.objects.create(username="rbac_reader", is_staff=False)
cls.reader_token = Token.objects.create(user=cls.reader)
Product_Member.objects.create(
product=cls.product_enabled, user=cls.reader,
role=Role.objects.get(id=Roles.Reader),
)

def create_finding(test, reporter, cve):
return Finding(
test=test, title=f"Finding {cve}", cve=cve, severity="High",
verified=True, description="Test", mitigation="Test",
impact="Test", reporter=reporter, numerical_severity="S1",
static_finding=True, dynamic_finding=False,
)

# Findings on the enabled product
Finding.objects.bulk_create(
create_finding(cls.test_enabled, cls.writer, f"CVE-2024-{i}") for i in range(10))
for f in Finding.objects.filter(test=cls.test_enabled):
Vulnerability_Id.objects.get_or_create(finding=f, vulnerability_id=f.cve)

# Findings on the disabled product
Finding.objects.bulk_create(
create_finding(cls.test_disabled, cls.writer, f"CVE-2024-{i + 100}") for i in range(5))
for f in Finding.objects.filter(test=cls.test_disabled):
Vulnerability_Id.objects.get_or_create(finding=f, vulnerability_id=f.cve)

def _client_for(self, token):
client = APIClient()
client.credentials(HTTP_AUTHORIZATION="Token " + token.key)
return client

def _accepted_risks(self, cve_ids):
return [{"vulnerability_id": cve, "justification": "Test", "accepted_by": "Tester"} for cve in cve_ids]

# --- Writer (has Risk_Acceptance) succeeds on enabled product ---

def test_writer_can_accept_risks_on_engagement(self):
client = self._client_for(self.writer_token)
result = client.post(
reverse("engagement-accept-risks", kwargs={"pk": self.engagement_enabled.id}),
data=self._accepted_risks(["CVE-2024-0"]),
format="json",
)
self.assertEqual(result.status_code, 201)

def test_writer_can_accept_risks_on_test(self):
client = self._client_for(self.writer_token)
result = client.post(
reverse("test-accept-risks", kwargs={"pk": self.test_enabled.id}),
data=self._accepted_risks(["CVE-2024-1"]),
format="json",
)
self.assertEqual(result.status_code, 201)

def test_writer_can_accept_risks_on_findings(self):
client = self._client_for(self.writer_token)
result = client.post(
reverse("finding-accept-risks"),
data=self._accepted_risks(["CVE-2024-2"]),
format="json",
)
self.assertEqual(result.status_code, 201)

# --- Reader (no Risk_Acceptance) is forbidden ---

def test_reader_forbidden_on_engagement(self):
client = self._client_for(self.reader_token)
result = client.post(
reverse("engagement-accept-risks", kwargs={"pk": self.engagement_enabled.id}),
data=self._accepted_risks(["CVE-2024-3"]),
format="json",
)
self.assertEqual(result.status_code, 403)

def test_reader_forbidden_on_test(self):
client = self._client_for(self.reader_token)
result = client.post(
reverse("test-accept-risks", kwargs={"pk": self.test_enabled.id}),
data=self._accepted_risks(["CVE-2024-4"]),
format="json",
)
self.assertEqual(result.status_code, 403)

def test_reader_gets_empty_result_on_findings(self):
client = self._client_for(self.reader_token)
result = client.post(
reverse("finding-accept-risks"),
data=self._accepted_risks(["CVE-2024-5"]),
format="json",
)
# Mass endpoint returns 201 with empty results (no authorized engagements)
self.assertEqual(result.status_code, 201)
self.assertEqual(len(result.json()), 0)

# --- enable_full_risk_acceptance=False blocks risk acceptance ---

def test_engagement_blocked_when_full_risk_acceptance_disabled(self):
client = self._client_for(self.writer_token)
result = client.post(
reverse("engagement-accept-risks", kwargs={"pk": self.engagement_disabled.id}),
data=self._accepted_risks(["CVE-2024-100"]),
format="json",
)
self.assertEqual(result.status_code, 403)

def test_test_blocked_when_full_risk_acceptance_disabled(self):
client = self._client_for(self.writer_token)
result = client.post(
reverse("test-accept-risks", kwargs={"pk": self.test_disabled.id}),
data=self._accepted_risks(["CVE-2024-101"]),
format="json",
)
self.assertEqual(result.status_code, 403)

def test_mass_endpoint_skips_disabled_products(self):
client = self._client_for(self.writer_token)
# Use a CVE that exists only on the disabled product
result = client.post(
reverse("finding-accept-risks"),
data=self._accepted_risks(["CVE-2024-102"]),
format="json",
)
self.assertEqual(result.status_code, 201)
# No risk acceptances created because the matching engagement's product has it disabled
self.assertEqual(len(result.json()), 0)
# Findings on disabled product remain unaccepted
self.assertEqual(self.engagement_disabled.unaccepted_open_findings.count(), 5)
Loading