Files
mostovik-backend/tests/apps/core/test_middleware.py

200 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for core middleware"""
import json
import logging
import uuid
from io import StringIO

from apps.core.middleware import (
    ApiCsrfExemptMiddleware,
    ApiSlashlessRouteMiddleware,
    OrganizationApiMetricsMiddleware,
    RequestIDMiddleware,
    RequestLoggingMiddleware,
    get_request_id,
)
from django.conf import settings
from django.http import HttpResponse
from django.test import RequestFactory
from django.urls import reverse
from organizations.models import Organization
from rest_framework.test import APITestCase
class RequestIDMiddlewareTest(APITestCase):
    """Tests for RequestIDMiddleware"""

    def test_request_id_generated(self):
        """Test that request ID is generated and returned in response header"""
        url = reverse("core:health")
        response = self.client.get(url)
        self.assertIn("X-Request-ID", response)
        request_id = response["X-Request-ID"]
        # Canonical UUID text form: 32 hex digits + 4 hyphens = 36 chars.
        self.assertEqual(len(request_id), 36)
        # A length check alone accepts any 36-char string; parsing rejects
        # non-hex content, and the round-trip pins the 8-4-4-4-12 layout.
        try:
            parsed = uuid.UUID(request_id)
        except ValueError:
            self.fail(f"X-Request-ID is not a UUID: {request_id!r}")
        self.assertEqual(str(parsed), request_id.lower())

    def test_request_id_passed_through(self):
        """Test that provided X-Request-ID is passed through"""
        url = reverse("core:health")
        custom_id = "custom-request-id-12345"
        response = self.client.get(url, HTTP_X_REQUEST_ID=custom_id)
        self.assertEqual(response["X-Request-ID"], custom_id)

    def test_different_requests_different_ids(self):
        """Test that different requests get different IDs"""
        url = reverse("core:health")
        response1 = self.client.get(url)
        response2 = self.client.get(url)
        self.assertNotEqual(response1["X-Request-ID"], response2["X-Request-ID"])
class RequestLoggingMiddlewareTest(APITestCase):
    """Tests for RequestLoggingMiddleware (plus one RequestIDMiddleware case)."""

    def setUp(self):
        self.factory = RequestFactory()
        self.logger = logging.getLogger("apps.core.middleware")
        # logging.getLogger returns a process-wide singleton. Snapshot its
        # level and handlers so they are restored after each test, instead of
        # leaking the forced INFO level and the StringIO handler (and dropping
        # any configured handlers) for the rest of the test run.
        original_level = self.logger.level
        original_handlers = self.logger.handlers[:]
        self.stream = StringIO()
        handler = logging.StreamHandler(self.stream)

        def restore_logger():
            self.logger.handlers = original_handlers
            self.logger.setLevel(original_level)
            handler.close()

        self.addCleanup(restore_logger)
        self.logger.setLevel(logging.INFO)
        # Only our handler is attached, so self.stream holds just this test's output.
        self.logger.handlers = []
        self.logger.addHandler(handler)

    def test_process_request_and_response_logs(self):
        """process_request/process_response emit a "Started" line and the status code."""
        middleware = RequestLoggingMiddleware(lambda req: HttpResponse(status=200))
        request = self.factory.get("/health/")
        request.request_id = "req-123"
        middleware.process_request(request)
        response = middleware.process_response(request, HttpResponse(status=200))
        self.assertEqual(response.status_code, 200)
        output = self.stream.getvalue()
        self.assertIn("Started", output)
        self.assertIn("200", output)

    def test_request_id_middleware_exception(self):
        """After process_exception, the response still gets X-Request-ID and get_request_id() is None."""
        # NOTE(review): this exercises RequestIDMiddleware, not
        # RequestLoggingMiddleware; consider moving it to RequestIDMiddlewareTest.
        middleware = RequestIDMiddleware(lambda req: None)
        request = self.factory.get("/health/")
        middleware.process_request(request)
        middleware.process_exception(request, RuntimeError("boom"))
        response = middleware.process_response(request, HttpResponse(status=200))
        self.assertIn("X-Request-ID", response)
        self.assertIsNone(get_request_id())
class ApiCsrfExemptMiddlewareTest(APITestCase):
    """CSRF checks are skipped for /api/ paths and kept for other paths."""

    def setUp(self):
        self.factory = RequestFactory()
        self.middleware = ApiCsrfExemptMiddleware(
            lambda req: HttpResponse(status=200)
        )

    def _run_process_view(self, path):
        """POST an empty form to *path* and return the middleware's process_view result."""
        request = self.factory.post(path, data={})
        return self.middleware.process_view(request, lambda req: None, (), {})

    def test_api_path_skips_csrf_check(self):
        self.assertIsNone(self._run_process_view("/api/v1/users/login/"))

    def test_non_api_path_keeps_csrf_check(self):
        rejection = self._run_process_view("/admin/login/")
        self.assertEqual(rejection.status_code, 403)
class ApiSlashlessRouteMiddlewareTest(APITestCase):
    """Slashless /api/ paths are rewritten only when a slash route exists."""

    def setUp(self):
        self.factory = RequestFactory()
        self.middleware = ApiSlashlessRouteMiddleware(
            lambda req: HttpResponse(status=200)
        )

    def _post_through_middleware(self, path):
        """Send an empty POST for *path* through process_request.

        Returns a ``(request, result)`` pair so callers can inspect both the
        possibly-rewritten request and the middleware's return value.
        """
        request = self.factory.post(path, data={})
        result = self.middleware.process_request(request)
        return request, result

    def test_api_path_without_slash_rewrites_to_existing_slash_route(self):
        request, result = self._post_through_middleware("/api/v1/users/login")
        self.assertIsNone(result)
        expected_path = "/api/v1/users/login/"
        self.assertEqual(request.path_info, expected_path)
        self.assertEqual(request.path, expected_path)

    def test_unknown_api_path_without_slash_is_unchanged(self):
        request, _ = self._post_through_middleware("/api/v1/unknown-route")
        self.assertEqual(request.path_info, "/api/v1/unknown-route")

    def test_non_api_path_without_slash_is_unchanged(self):
        request, _ = self._post_through_middleware("/admin/login")
        self.assertEqual(request.path_info, "/admin/login")
class OrganizationApiMetricsMiddlewareTest(APITestCase):
    """Tests for OrganizationApiMetricsMiddleware."""

    # Logger the middleware's metrics lines are captured from.
    _METRICS_LOGGER = "organizations.api.metrics"

    def setUp(self):
        self.factory = RequestFactory()

    def _make_request(self, path, params=None):
        """Build a GET request carrying the attributes these tests attach.

        ``request_id`` is echoed back in the metrics payload; ``user`` is a
        minimal stand-in exposing only ``is_authenticated = False``.
        """
        request = self.factory.get(path, params)
        request.request_id = "metrics-request"
        request.user = type("AnonymousUser", (), {"is_authenticated": False})()
        return request

    def test_middleware_is_enabled_for_organization_endpoint_metrics(self):
        self.assertIn(
            "apps.core.middleware.OrganizationApiMetricsMiddleware",
            settings.MIDDLEWARE,
        )

    def test_logs_organization_endpoint_metrics_without_query_values(self):
        Organization.objects.create(name='ООО "Метрика"', inn="7707083893")

        def get_response(request):
            # Issue one DB query so db_query_count has something to count.
            Organization.objects.count()
            response = HttpResponse("ok", status=200)
            response["X-Cache"] = "MISS"
            return response

        middleware = OrganizationApiMetricsMiddleware(get_response)
        request = self._make_request(
            "/api/v2/organizations/",
            {"page": "1", "page_size": "20", "search": "7707083893"},
        )
        with self.assertLogs(self._METRICS_LOGGER, level="INFO") as captured:
            response = middleware(request)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(len(captured.output), 1)
        _, payload = captured.output[0].split("organization_api_metrics ", 1)
        metrics = json.loads(payload)
        self.assertEqual(metrics["request_id"], "metrics-request")
        self.assertEqual(metrics["method"], "GET")
        self.assertEqual(metrics["path"], "/api/v2/organizations/")
        self.assertEqual(metrics["status_code"], 200)
        self.assertEqual(metrics["cache"], "MISS")
        self.assertEqual(metrics["query_keys"], ["page", "page_size", "search"])
        self.assertGreaterEqual(metrics["db_query_count"], 1)
        self.assertGreater(metrics["duration_ms"], 0)
        self.assertGreater(metrics["response_size_bytes"], 0)
        # Only query keys may be logged; the search value itself must not leak.
        self.assertNotIn("7707083893", captured.output[0])

    def test_ignores_non_organization_api_paths(self):
        middleware = OrganizationApiMetricsMiddleware(
            lambda request: HttpResponse("ok", status=200)
        )
        request = self._make_request("/api/v2/sources/")
        logger = logging.getLogger(self._METRICS_LOGGER)
        stream = StringIO()
        handler = logging.StreamHandler(stream)
        previous_level = logger.level
        # Let INFO records reach the handler (the sibling test captures at
        # INFO). Without this, an effective level of WARNING would drop such
        # records before the handler and the empty-output assertion below
        # would pass even if the middleware did log.
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)
        try:
            response = middleware(request)
        finally:
            logger.removeHandler(handler)
            logger.setLevel(previous_level)
            handler.close()
        self.assertEqual(response.status_code, 200)
        self.assertEqual(stream.getvalue(), "")