From 90856d5a7e1f683c9a7b033a904dd9a71259c71d Mon Sep 17 00:00:00 2001 From: Aleksandr Meshchriakov Date: Wed, 29 Apr 2026 12:09:56 +0200 Subject: [PATCH] fix(api): disable csrf checks for api routes --- src/apps/core/middleware.py | 12 ++++++++++++ src/settings/base.py | 2 +- src/settings/dev.py | 5 ----- tests/apps/core/test_middleware.py | 21 +++++++++++++++++++++ 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/apps/core/middleware.py b/src/apps/core/middleware.py index 8effb6e..ed88e52 100644 --- a/src/apps/core/middleware.py +++ b/src/apps/core/middleware.py @@ -8,6 +8,7 @@ import logging import threading import uuid +from django.middleware.csrf import CsrfViewMiddleware from django.utils.deprecation import MiddlewareMixin logger = logging.getLogger(__name__) @@ -16,6 +17,17 @@ logger = logging.getLogger(__name__) _request_context = threading.local() +class ApiCsrfExemptMiddleware(CsrfViewMiddleware): + """Skip CSRF checks for JWT/DRF API routes while keeping CSRF elsewhere.""" + + api_prefixes = ("/api/",) + + def process_view(self, request, callback, callback_args, callback_kwargs): + if request.path_info.startswith(self.api_prefixes): + return None + return super().process_view(request, callback, callback_args, callback_kwargs) + + def get_request_id() -> str | None: """Get current request ID from thread-local storage.""" return getattr(_request_context, "request_id", None) diff --git a/src/settings/base.py b/src/settings/base.py index 1198f95..ef9e260 100644 --- a/src/settings/base.py +++ b/src/settings/base.py @@ -167,7 +167,7 @@ MIDDLEWARE = [ "whitenoise.middleware.WhiteNoiseMiddleware", "django.contrib.sessions.middleware.SessionMiddleware", "django.middleware.common.CommonMiddleware", - "django.middleware.csrf.CsrfViewMiddleware", + "apps.core.middleware.ApiCsrfExemptMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "django.middleware.clickjacking.XFrameOptionsMiddleware", diff --git a/src/settings/dev.py b/src/settings/dev.py index 2cd49f9..adf9c7e 100644 --- a/src/settings/dev.py +++ b/src/settings/dev.py @@ -29,11 +29,6 @@ CORS_ALLOW_CREDENTIALS = True CORS_ALLOW_PRIVATE_NETWORK = True CSRF_COOKIE_SECURE = False SESSION_COOKIE_SECURE = False -MIDDLEWARE = [ - middleware - for middleware in MIDDLEWARE - if middleware != "django.middleware.csrf.CsrfViewMiddleware" -] def _normalize_local_host(host: str) -> str: diff --git a/tests/apps/core/test_middleware.py b/tests/apps/core/test_middleware.py index 942885f..421d009 100644 --- a/tests/apps/core/test_middleware.py +++ b/tests/apps/core/test_middleware.py @@ -4,6 +4,7 @@ import logging from io import StringIO from apps.core.middleware import ( + ApiCsrfExemptMiddleware, RequestIDMiddleware, RequestLoggingMiddleware, get_request_id, @@ -73,3 +74,23 @@ class RequestLoggingMiddlewareTest(APITestCase): response = middleware.process_response(request, HttpResponse(status=200)) self.assertIn("X-Request-ID", response) self.assertIsNone(get_request_id()) + + +class ApiCsrfExemptMiddlewareTest(APITestCase): + def setUp(self): + self.factory = RequestFactory() + self.middleware = ApiCsrfExemptMiddleware(lambda req: HttpResponse(status=200)) + + def test_api_path_skips_csrf_check(self): + request = self.factory.post("/api/v1/users/login/", data={}) + + response = self.middleware.process_view(request, lambda req: None, (), {}) + + self.assertIsNone(response) + + def test_non_api_path_keeps_csrf_check(self): + request = self.factory.post("/admin/login/", data={}) + + response = self.middleware.process_view(request, lambda req: None, (), {}) + + self.assertEqual(response.status_code, 403)