"""Tests for core middleware""" import json import logging from io import StringIO from apps.core.middleware import ( ApiCsrfExemptMiddleware, ApiSlashlessRouteMiddleware, OrganizationApiMetricsMiddleware, RequestIDMiddleware, RequestLoggingMiddleware, get_request_id, ) from django.conf import settings from django.http import HttpResponse from django.test import RequestFactory from django.urls import reverse from organizations.models import Organization from rest_framework.test import APITestCase class RequestIDMiddlewareTest(APITestCase): """Tests for RequestIDMiddleware""" def test_request_id_generated(self): """Test that request ID is generated and returned in response header""" url = reverse("core:health") response = self.client.get(url) self.assertIn("X-Request-ID", response) self.assertIsNotNone(response["X-Request-ID"]) # UUID format check (36 chars with hyphens) self.assertEqual(len(response["X-Request-ID"]), 36) def test_request_id_passed_through(self): """Test that provided X-Request-ID is passed through""" url = reverse("core:health") custom_id = "custom-request-id-12345" response = self.client.get(url, HTTP_X_REQUEST_ID=custom_id) self.assertEqual(response["X-Request-ID"], custom_id) def test_different_requests_different_ids(self): """Test that different requests get different IDs""" url = reverse("core:health") response1 = self.client.get(url) response2 = self.client.get(url) self.assertNotEqual(response1["X-Request-ID"], response2["X-Request-ID"]) class RequestLoggingMiddlewareTest(APITestCase): def setUp(self): self.factory = RequestFactory() self.logger = logging.getLogger("apps.core.middleware") self.logger.setLevel(logging.INFO) self.stream = StringIO() handler = logging.StreamHandler(self.stream) self.logger.handlers = [] self.logger.addHandler(handler) def test_process_request_and_response_logs(self): middleware = RequestLoggingMiddleware(lambda req: HttpResponse(status=200)) request = self.factory.get("/health/") request.request_id = "req-123" middleware.process_request(request) response = middleware.process_response(request, HttpResponse(status=200)) self.assertEqual(response.status_code, 200) output = self.stream.getvalue() self.assertIn("Started", output) self.assertIn("200", output) def test_request_id_middleware_exception(self): middleware = RequestIDMiddleware(lambda req: None) request = self.factory.get("/health/") middleware.process_request(request) middleware.process_exception(request, RuntimeError("boom")) response = middleware.process_response(request, HttpResponse(status=200)) self.assertIn("X-Request-ID", response) self.assertIsNone(get_request_id()) class ApiCsrfExemptMiddlewareTest(APITestCase): def setUp(self): self.factory = RequestFactory() self.middleware = ApiCsrfExemptMiddleware(lambda req: HttpResponse(status=200)) def test_api_path_skips_csrf_check(self): request = self.factory.post("/api/v1/users/login/", data={}) response = self.middleware.process_view(request, lambda req: None, (), {}) self.assertIsNone(response) def test_non_api_path_keeps_csrf_check(self): request = self.factory.post("/admin/login/", data={}) response = self.middleware.process_view(request, lambda req: None, (), {}) self.assertEqual(response.status_code, 403) class ApiSlashlessRouteMiddlewareTest(APITestCase): def setUp(self): self.factory = RequestFactory() self.middleware = ApiSlashlessRouteMiddleware( lambda req: HttpResponse(status=200) ) def test_api_path_without_slash_rewrites_to_existing_slash_route(self): request = self.factory.post("/api/v1/users/login", data={}) response = self.middleware.process_request(request) self.assertIsNone(response) self.assertEqual(request.path_info, "/api/v1/users/login/") self.assertEqual(request.path, "/api/v1/users/login/") def test_unknown_api_path_without_slash_is_unchanged(self): request = self.factory.post("/api/v1/unknown-route", data={}) self.middleware.process_request(request) self.assertEqual(request.path_info, "/api/v1/unknown-route") def test_non_api_path_without_slash_is_unchanged(self): request = self.factory.post("/admin/login", data={}) self.middleware.process_request(request) self.assertEqual(request.path_info, "/admin/login") class OrganizationApiMetricsMiddlewareTest(APITestCase): def setUp(self): self.factory = RequestFactory() def test_middleware_is_enabled_for_organization_endpoint_metrics(self): self.assertIn( "apps.core.middleware.OrganizationApiMetricsMiddleware", settings.MIDDLEWARE, ) def test_logs_organization_endpoint_metrics_without_query_values(self): Organization.objects.create(name='ООО "Метрика"', inn="7707083893") def get_response(request): Organization.objects.count() response = HttpResponse("ok", status=200) response["X-Cache"] = "MISS" return response middleware = OrganizationApiMetricsMiddleware(get_response) request = self.factory.get( "/api/v2/organizations/", {"page": "1", "page_size": "20", "search": "7707083893"}, ) request.request_id = "metrics-request" request.user = type("AnonymousUser", (), {"is_authenticated": False})() with self.assertLogs("organizations.api.metrics", level="INFO") as captured: response = middleware(request) self.assertEqual(response.status_code, 200) self.assertEqual(len(captured.output), 1) _, payload = captured.output[0].split("organization_api_metrics ", 1) metrics = json.loads(payload) self.assertEqual(metrics["request_id"], "metrics-request") self.assertEqual(metrics["method"], "GET") self.assertEqual(metrics["path"], "/api/v2/organizations/") self.assertEqual(metrics["status_code"], 200) self.assertEqual(metrics["cache"], "MISS") self.assertEqual(metrics["query_keys"], ["page", "page_size", "search"]) self.assertGreaterEqual(metrics["db_query_count"], 1) self.assertGreater(metrics["duration_ms"], 0) self.assertGreater(metrics["response_size_bytes"], 0) self.assertNotIn("7707083893", captured.output[0]) def test_ignores_non_organization_api_paths(self): middleware = OrganizationApiMetricsMiddleware( lambda request: HttpResponse("ok", status=200) ) request = self.factory.get("/api/v2/sources/") request.request_id = "metrics-request" request.user = type("AnonymousUser", (), {"is_authenticated": False})() logger = logging.getLogger("organizations.api.metrics") stream = StringIO() handler = logging.StreamHandler(stream) logger.addHandler(handler) try: response = middleware(request) finally: logger.removeHandler(handler) self.assertEqual(response.status_code, 200) self.assertEqual(stream.getvalue(), "")