200 lines
7.4 KiB
Python
200 lines
7.4 KiB
Python
"""Tests for core middleware"""
|
||
|
||
import json
|
||
import logging
|
||
from io import StringIO
|
||
|
||
from apps.core.middleware import (
|
||
ApiCsrfExemptMiddleware,
|
||
ApiSlashlessRouteMiddleware,
|
||
OrganizationApiMetricsMiddleware,
|
||
RequestIDMiddleware,
|
||
RequestLoggingMiddleware,
|
||
get_request_id,
|
||
)
|
||
from django.conf import settings
|
||
from django.http import HttpResponse
|
||
from django.test import RequestFactory
|
||
from django.urls import reverse
|
||
from organizations.models import Organization
|
||
from rest_framework.test import APITestCase
|
||
|
||
|
||
class RequestIDMiddlewareTest(APITestCase):
    """Tests for RequestIDMiddleware"""

    def test_request_id_generated(self):
        """Test that request ID is generated and returned in response header"""
        import uuid

        url = reverse("core:health")
        response = self.client.get(url)

        # Header lookup raises KeyError when missing, so assertIn is the
        # meaningful presence check (a lookup result is never None).
        self.assertIn("X-Request-ID", response)
        request_id = response["X-Request-ID"]

        # UUID format check: canonical hyphenated form is 36 chars. Parsing
        # with uuid.UUID rejects arbitrary 36-char strings that a length-only
        # check would accept; comparing against str() pins the hyphenated form.
        self.assertEqual(len(request_id), 36)
        self.assertEqual(str(uuid.UUID(request_id)), request_id.lower())

    def test_request_id_passed_through(self):
        """Test that provided X-Request-ID is passed through"""
        url = reverse("core:health")
        custom_id = "custom-request-id-12345"
        response = self.client.get(url, HTTP_X_REQUEST_ID=custom_id)

        self.assertEqual(response["X-Request-ID"], custom_id)

    def test_different_requests_different_ids(self):
        """Test that different requests get different IDs"""
        url = reverse("core:health")
        response1 = self.client.get(url)
        response2 = self.client.get(url)

        self.assertNotEqual(response1["X-Request-ID"], response2["X-Request-ID"])
|
||
|
||
|
||
class RequestLoggingMiddlewareTest(APITestCase):
    """Tests for RequestLoggingMiddleware, plus a RequestIDMiddleware exception-path check."""

    def setUp(self):
        """Route the ``apps.core.middleware`` logger into an in-memory stream.

        The logger is process-global, so its original level and handlers are
        captured and restored via ``addCleanup`` to keep other tests isolated.
        """
        self.factory = RequestFactory()
        self.logger = logging.getLogger("apps.core.middleware")
        self.stream = StringIO()
        handler = logging.StreamHandler(self.stream)

        original_level = self.logger.level
        original_handlers = self.logger.handlers[:]

        def restore_logger():
            # Put back exactly what was there before this test touched it.
            self.logger.handlers = original_handlers
            self.logger.setLevel(original_level)
            handler.close()

        # Registered before mutating so the logger is restored even if the
        # rest of setUp fails.
        self.addCleanup(restore_logger)

        self.logger.setLevel(logging.INFO)
        self.logger.handlers = []
        self.logger.addHandler(handler)

    def test_process_request_and_response_logs(self):
        """process_request/process_response pass the response through and write log output."""
        middleware = RequestLoggingMiddleware(lambda req: HttpResponse(status=200))
        request = self.factory.get("/health/")
        request.request_id = "req-123"

        middleware.process_request(request)
        response = middleware.process_response(request, HttpResponse(status=200))

        self.assertEqual(response.status_code, 200)
        output = self.stream.getvalue()
        self.assertIn("Started", output)
        self.assertIn("200", output)

    def test_request_id_middleware_exception(self):
        """RequestIDMiddleware still sets X-Request-ID after process_exception.

        Also checks that get_request_id() returns None once the response has
        been processed.
        """
        middleware = RequestIDMiddleware(lambda req: None)
        request = self.factory.get("/health/")

        middleware.process_request(request)
        middleware.process_exception(request, RuntimeError("boom"))
        response = middleware.process_response(request, HttpResponse(status=200))

        self.assertIn("X-Request-ID", response)
        self.assertIsNone(get_request_id())
|
||
|
||
|
||
class ApiCsrfExemptMiddlewareTest(APITestCase):
    """Tests for ApiCsrfExemptMiddleware.process_view on token-less POSTs."""

    def setUp(self):
        self.factory = RequestFactory()
        self.middleware = ApiCsrfExemptMiddleware(lambda req: HttpResponse(status=200))

    def _call_process_view(self, path):
        """POST an empty form (no CSRF token) to *path*; return process_view's result."""
        post_request = self.factory.post(path, data={})
        return self.middleware.process_view(post_request, lambda req: None, (), {})

    def test_api_path_skips_csrf_check(self):
        """For an /api/ path, process_view returns None instead of a rejection."""
        outcome = self._call_process_view("/api/v1/users/login/")
        self.assertIsNone(outcome)

    def test_non_api_path_keeps_csrf_check(self):
        """For a non-API path, process_view returns a 403 response."""
        rejection = self._call_process_view("/admin/login/")
        self.assertEqual(rejection.status_code, 403)
|
||
|
||
|
||
class ApiSlashlessRouteMiddlewareTest(APITestCase):
    """Tests for ApiSlashlessRouteMiddleware.process_request path rewriting."""

    def setUp(self):
        self.factory = RequestFactory()
        self.middleware = ApiSlashlessRouteMiddleware(
            lambda req: HttpResponse(status=200)
        )

    def _empty_post(self, path):
        """Build a POST request with an empty form body for *path*."""
        return self.factory.post(path, data={})

    def test_api_path_without_slash_rewrites_to_existing_slash_route(self):
        """A slashless API path whose slashed form is routed gets a trailing slash."""
        req = self._empty_post("/api/v1/users/login")
        slashed = "/api/v1/users/login/"

        result = self.middleware.process_request(req)

        self.assertIsNone(result)
        self.assertEqual(req.path_info, slashed)
        self.assertEqual(req.path, slashed)

    def test_unknown_api_path_without_slash_is_unchanged(self):
        """A slashless API path with no known route keeps its original path_info."""
        untouched = "/api/v1/unknown-route"
        req = self._empty_post(untouched)

        self.middleware.process_request(req)

        self.assertEqual(req.path_info, untouched)

    def test_non_api_path_without_slash_is_unchanged(self):
        """A slashless path outside /api/ keeps its original path_info."""
        untouched = "/admin/login"
        req = self._empty_post(untouched)

        self.middleware.process_request(req)

        self.assertEqual(req.path_info, untouched)
|
||
|
||
|
||
class OrganizationApiMetricsMiddlewareTest(APITestCase):
    """Tests for OrganizationApiMetricsMiddleware."""

    def setUp(self):
        self.factory = RequestFactory()

    @staticmethod
    def _anonymous_user():
        """Return a minimal stand-in user object whose ``is_authenticated`` is False."""
        return type("AnonymousUser", (), {"is_authenticated": False})()

    def test_middleware_is_enabled_for_organization_endpoint_metrics(self):
        """The middleware is registered in settings.MIDDLEWARE."""
        self.assertIn(
            "apps.core.middleware.OrganizationApiMetricsMiddleware",
            settings.MIDDLEWARE,
        )

    def test_logs_organization_endpoint_metrics_without_query_values(self):
        """One JSON metrics record is logged; it lists query keys but not their values."""
        Organization.objects.create(name='ООО "Метрика"', inn="7707083893")

        def get_response(request):
            # Run one DB query so db_query_count has something to count.
            Organization.objects.count()
            response = HttpResponse("ok", status=200)
            response["X-Cache"] = "MISS"
            return response

        middleware = OrganizationApiMetricsMiddleware(get_response)
        request = self.factory.get(
            "/api/v2/organizations/",
            {"page": "1", "page_size": "20", "search": "7707083893"},
        )
        request.request_id = "metrics-request"
        request.user = self._anonymous_user()

        with self.assertLogs("organizations.api.metrics", level="INFO") as captured:
            response = middleware(request)

        self.assertEqual(response.status_code, 200)
        self.assertEqual(len(captured.output), 1)
        # Record text is "<LEVEL>:<logger>:organization_api_metrics <json>".
        _, payload = captured.output[0].split("organization_api_metrics ", 1)
        metrics = json.loads(payload)

        self.assertEqual(metrics["request_id"], "metrics-request")
        self.assertEqual(metrics["method"], "GET")
        self.assertEqual(metrics["path"], "/api/v2/organizations/")
        self.assertEqual(metrics["status_code"], 200)
        self.assertEqual(metrics["cache"], "MISS")
        self.assertEqual(metrics["query_keys"], ["page", "page_size", "search"])
        self.assertGreaterEqual(metrics["db_query_count"], 1)
        # NOTE(review): if the middleware rounds duration_ms coarsely, a very
        # fast request could round to 0 and make this flaky — confirm.
        self.assertGreater(metrics["duration_ms"], 0)
        self.assertGreater(metrics["response_size_bytes"], 0)
        # The search value must never leak into the log line.
        self.assertNotIn("7707083893", captured.output[0])

    def test_ignores_non_organization_api_paths(self):
        """A non-organization API path (/api/v2/sources/) produces no metrics log output."""
        middleware = OrganizationApiMetricsMiddleware(
            lambda request: HttpResponse("ok", status=200)
        )
        request = self.factory.get("/api/v2/sources/")
        request.request_id = "metrics-request"
        request.user = self._anonymous_user()

        logger = logging.getLogger("organizations.api.metrics")
        stream = StringIO()
        handler = logging.StreamHandler(stream)
        original_level = logger.level
        # Without an explicit level the logger may inherit WARNING from the
        # root logger. An INFO metrics record (the level the positive test
        # captures at) would then be dropped before reaching the handler, and
        # the empty-stream assertion would pass even if path filtering broke.
        logger.setLevel(logging.INFO)
        logger.addHandler(handler)
        try:
            response = middleware(request)
        finally:
            logger.removeHandler(handler)
            logger.setLevel(original_level)
            handler.close()

        self.assertEqual(response.status_code, 200)
        self.assertEqual(stream.getvalue(), "")
|