129 lines
4.6 KiB
Python
129 lines
4.6 KiB
Python
"""Tests for core middleware"""
|
|
|
|
import logging
|
|
from io import StringIO
|
|
|
|
from apps.core.middleware import (
|
|
ApiCsrfExemptMiddleware,
|
|
ApiSlashlessRouteMiddleware,
|
|
RequestIDMiddleware,
|
|
RequestLoggingMiddleware,
|
|
get_request_id,
|
|
)
|
|
from django.http import HttpResponse
|
|
from django.test import RequestFactory
|
|
from django.urls import reverse
|
|
from rest_framework.test import APITestCase
|
|
|
|
|
|
class RequestIDMiddlewareTest(APITestCase):
    """Tests for RequestIDMiddleware, exercised through the test client."""

    def test_request_id_generated(self):
        """Test that a request ID is generated and returned in the response header."""
        import uuid  # local import: only this test needs it

        url = reverse("core:health")
        response = self.client.get(url)

        # `in` on an HttpResponse checks for header presence; indexing a missing
        # header would raise KeyError, so no separate None check is needed.
        self.assertIn("X-Request-ID", response)
        request_id = response["X-Request-ID"]

        # The ID must be a UUID in canonical hyphenated form (36 chars). A bare
        # length check would accept any 36-character string, so parse it and
        # require that it round-trips to the canonical representation.
        try:
            parsed = uuid.UUID(request_id)
        except ValueError:
            self.fail(f"X-Request-ID is not a UUID: {request_id!r}")
        self.assertEqual(str(parsed), request_id.lower())
        self.assertEqual(len(request_id), 36)

    def test_request_id_passed_through(self):
        """Test that a client-supplied X-Request-ID is echoed back unchanged."""
        url = reverse("core:health")
        custom_id = "custom-request-id-12345"
        response = self.client.get(url, HTTP_X_REQUEST_ID=custom_id)

        self.assertEqual(response["X-Request-ID"], custom_id)

    def test_different_requests_different_ids(self):
        """Test that two requests without an incoming ID get different IDs."""
        url = reverse("core:health")
        response1 = self.client.get(url)
        response2 = self.client.get(url)

        self.assertNotEqual(response1["X-Request-ID"], response2["X-Request-ID"])
|
|
|
|
|
|
class RequestLoggingMiddlewareTest(APITestCase):
    """Tests for RequestLoggingMiddleware's log output.

    NOTE(review): test_request_id_middleware_exception exercises
    RequestIDMiddleware rather than the logging middleware; it may belong in
    RequestIDMiddlewareTest.
    """

    def setUp(self):
        """Route the middleware logger's output into an in-memory stream.

        The "apps.core.middleware" logger is process-global state shared with
        every other test, so its original level and handlers are snapshotted
        here and restored after each test via addCleanup.
        """
        self.factory = RequestFactory()
        self.logger = logging.getLogger("apps.core.middleware")

        # Snapshot global logger state before mutating it.
        original_level = self.logger.level
        original_handlers = list(self.logger.handlers)

        def restore_logger():
            self.logger.handlers = original_handlers
            # setLevel (not direct assignment) also clears the logging cache.
            self.logger.setLevel(original_level)

        self.addCleanup(restore_logger)

        self.logger.setLevel(logging.INFO)
        self.stream = StringIO()
        handler = logging.StreamHandler(self.stream)
        # Detach existing handlers so only this test's stream receives records.
        self.logger.handlers = []
        self.logger.addHandler(handler)
        # Cleanups run LIFO: the handler is closed before the logger is restored.
        self.addCleanup(handler.close)

    def test_process_request_and_response_logs(self):
        """Request start and response status are both written to the log."""
        middleware = RequestLoggingMiddleware(lambda req: HttpResponse(status=200))
        request = self.factory.get("/health/")
        # The factory request never passes through RequestIDMiddleware, so set
        # the attribute by hand.
        request.request_id = "req-123"

        middleware.process_request(request)
        response = middleware.process_response(request, HttpResponse(status=200))

        self.assertEqual(response.status_code, 200)
        output = self.stream.getvalue()
        self.assertIn("Started", output)
        self.assertIn("200", output)

    def test_request_id_middleware_exception(self):
        """An exception mid-request still yields an X-Request-ID response header."""
        middleware = RequestIDMiddleware(lambda req: None)
        request = self.factory.get("/health/")

        middleware.process_request(request)
        middleware.process_exception(request, RuntimeError("boom"))
        response = middleware.process_response(request, HttpResponse(status=200))

        self.assertIn("X-Request-ID", response)
        # After the response is processed, get_request_id() must return None —
        # presumably the middleware clears per-request state in
        # process_response; confirm against the implementation.
        self.assertIsNone(get_request_id())
|
|
|
|
|
|
class ApiCsrfExemptMiddlewareTest(APITestCase):
    """CSRF checks are skipped for API paths but still enforced elsewhere."""

    def setUp(self):
        self.factory = RequestFactory()
        self.middleware = ApiCsrfExemptMiddleware(lambda req: HttpResponse(status=200))

    def _process_post(self, path):
        """POST an empty form to *path* and return the process_view result.

        A return of None means the middleware let the request through; the
        non-API case below expects a 403 response, i.e. CSRF is enforced for
        these factory-built requests.
        """
        post_request = self.factory.post(path, data={})
        return self.middleware.process_view(post_request, lambda req: None, (), {})

    def test_api_path_skips_csrf_check(self):
        outcome = self._process_post("/api/v1/users/login/")

        self.assertIsNone(outcome)

    def test_non_api_path_keeps_csrf_check(self):
        outcome = self._process_post("/admin/login/")

        self.assertEqual(outcome.status_code, 403)
|
|
|
|
|
|
class ApiSlashlessRouteMiddlewareTest(APITestCase):
    """Tests for ApiSlashlessRouteMiddleware's trailing-slash path rewriting."""

    def setUp(self):
        self.factory = RequestFactory()
        self.middleware = ApiSlashlessRouteMiddleware(
            lambda req: HttpResponse(status=200)
        )

    def test_api_path_without_slash_rewrites_to_existing_slash_route(self):
        """A slashless API path whose slashed form exists is rewritten in place."""
        request = self.factory.post("/api/v1/users/login", data={})

        response = self.middleware.process_request(request)

        # None means processing continues to the view (no short-circuit such
        # as a redirect); the request itself now points at the slashed route.
        self.assertIsNone(response)
        self.assertEqual(request.path_info, "/api/v1/users/login/")
        self.assertEqual(request.path, "/api/v1/users/login/")

    def test_unknown_api_path_without_slash_is_unchanged(self):
        """A slashless API path with no slashed counterpart is left alone."""
        request = self.factory.post("/api/v1/unknown-route", data={})

        self.middleware.process_request(request)

        # Check both attributes, mirroring the rewrite test, so a partial
        # rewrite of only one of them is caught.
        self.assertEqual(request.path_info, "/api/v1/unknown-route")
        self.assertEqual(request.path, "/api/v1/unknown-route")

    def test_non_api_path_without_slash_is_unchanged(self):
        """Paths outside the API prefix are never rewritten."""
        request = self.factory.post("/admin/login", data={})

        self.middleware.process_request(request)

        self.assertEqual(request.path_info, "/admin/login")
        self.assertEqual(request.path, "/admin/login")
|