"""Tests for core middleware"""

import logging
import uuid
from io import StringIO

from apps.core.middleware import (
    ApiCsrfExemptMiddleware,
    ApiSlashlessRouteMiddleware,
    RequestIDMiddleware,
    RequestLoggingMiddleware,
    get_request_id,
)
from django.http import HttpResponse
from django.test import RequestFactory
from django.urls import reverse
from rest_framework.test import APITestCase


class RequestIDMiddlewareTest(APITestCase):
    """Tests for RequestIDMiddleware"""

    def test_request_id_generated(self):
        """Test that request ID is generated and returned in response header"""
        url = reverse("core:health")
        response = self.client.get(url)

        self.assertIn("X-Request-ID", response)
        request_id = response["X-Request-ID"]
        # Canonical UUID string form is 36 chars (32 hex digits + 4 hyphens).
        # Parsing the value, not just measuring it, rejects arbitrary
        # 36-character strings.
        self.assertEqual(len(request_id), 36)
        try:
            uuid.UUID(request_id)
        except ValueError:
            self.fail(f"X-Request-ID is not a valid UUID: {request_id!r}")

    def test_request_id_passed_through(self):
        """Test that provided X-Request-ID is passed through"""
        url = reverse("core:health")
        custom_id = "custom-request-id-12345"
        response = self.client.get(url, HTTP_X_REQUEST_ID=custom_id)

        self.assertEqual(response["X-Request-ID"], custom_id)

    def test_different_requests_different_ids(self):
        """Test that different requests get different IDs"""
        url = reverse("core:health")
        response1 = self.client.get(url)
        response2 = self.client.get(url)

        self.assertNotEqual(response1["X-Request-ID"], response2["X-Request-ID"])


class RequestLoggingMiddlewareTest(APITestCase):
    """Tests for RequestLoggingMiddleware and RequestIDMiddleware hooks.

    Log output of the ``apps.core.middleware`` logger is captured into an
    in-memory stream for the duration of each test.
    """

    def setUp(self):
        self.factory = RequestFactory()
        self.logger = logging.getLogger("apps.core.middleware")
        self.stream = StringIO()
        handler = logging.StreamHandler(self.stream)

        # Loggers are process-global, so snapshot the current configuration
        # and register the restore BEFORE mutating anything. Cleanups run
        # even if setUp fails part-way, so other tests never see our handler
        # or the forced level.
        self.addCleanup(
            self._restore_logger,
            handler,
            self.logger.level,
            self.logger.handlers[:],
        )

        self.logger.setLevel(logging.INFO)
        # Route the logger only to our capture stream while the test runs.
        self.logger.handlers = []
        self.logger.addHandler(handler)

    def _restore_logger(self, handler, level, handlers):
        """Close the capture handler and reinstate the saved level/handlers."""
        handler.close()
        self.logger.handlers = handlers
        self.logger.setLevel(level)

    def test_process_request_and_response_logs(self):
        """process_request/process_response pass the response through and log
        a "Started" line plus the response status code."""
        middleware = RequestLoggingMiddleware(lambda req: HttpResponse(status=200))
        request = self.factory.get("/health/")
        request.request_id = "req-123"

        middleware.process_request(request)
        response = middleware.process_response(request, HttpResponse(status=200))
        self.assertEqual(response.status_code, 200)
        output = self.stream.getvalue()
        self.assertIn("Started", output)
        self.assertIn("200", output)

    def test_request_id_middleware_exception(self):
        """After process_exception, process_response still sets X-Request-ID,
        and get_request_id() returns None once the response is processed."""
        middleware = RequestIDMiddleware(lambda req: None)
        request = self.factory.get("/health/")
        middleware.process_request(request)
        middleware.process_exception(request, RuntimeError("boom"))
        response = middleware.process_response(request, HttpResponse(status=200))
        self.assertIn("X-Request-ID", response)
        self.assertIsNone(get_request_id())


class ApiCsrfExemptMiddlewareTest(APITestCase):
    """Tests that ApiCsrfExemptMiddleware skips CSRF for /api/ paths only.

    Requests are built with RequestFactory rather than the test Client,
    because the Client disables CSRF checks by default.
    """

    def setUp(self):
        self.factory = RequestFactory()
        self.middleware = ApiCsrfExemptMiddleware(lambda req: HttpResponse(status=200))

    def test_api_path_skips_csrf_check(self):
        """A POST under /api/ without a CSRF token is allowed through (None)."""
        request = self.factory.post("/api/v1/users/login/", data={})
        response = self.middleware.process_view(request, lambda req: None, (), {})
        self.assertIsNone(response)

    def test_non_api_path_keeps_csrf_check(self):
        """A POST outside /api/ without a CSRF token is rejected with 403."""
        request = self.factory.post("/admin/login/", data={})
        response = self.middleware.process_view(request, lambda req: None, (), {})
        # Fail cleanly (not with AttributeError) if the check was skipped.
        self.assertIsNotNone(response)
        self.assertEqual(response.status_code, 403)


class ApiSlashlessRouteMiddlewareTest(APITestCase):
    """Tests for ApiSlashlessRouteMiddleware path rewriting.

    Slashless /api/ paths are rewritten to the trailing-slash route when one
    exists; unknown API routes and non-API paths are left unchanged.
    """

    def setUp(self):
        self.factory = RequestFactory()
        self.middleware = ApiSlashlessRouteMiddleware(
            lambda req: HttpResponse(status=200)
        )

    def test_api_path_without_slash_rewrites_to_existing_slash_route(self):
        """/api/v1/users/login is rewritten in place to /api/v1/users/login/."""
        request = self.factory.post("/api/v1/users/login", data={})
        response = self.middleware.process_request(request)
        self.assertIsNone(response)
        self.assertEqual(request.path_info, "/api/v1/users/login/")
        self.assertEqual(request.path, "/api/v1/users/login/")

    def test_unknown_api_path_without_slash_is_unchanged(self):
        """An /api/ path with no matching slash route keeps its path_info."""
        request = self.factory.post("/api/v1/unknown-route", data={})
        self.middleware.process_request(request)
        self.assertEqual(request.path_info, "/api/v1/unknown-route")

    def test_non_api_path_without_slash_is_unchanged(self):
        """Paths outside /api/ are never rewritten."""
        request = self.factory.post("/admin/login", data={})
        self.middleware.process_request(request)
        self.assertEqual(request.path_info, "/admin/login")