fix(api): disable csrf checks for api routes
This commit is contained in:
@@ -8,6 +8,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from django.middleware.csrf import CsrfViewMiddleware
|
||||||
from django.utils.deprecation import MiddlewareMixin
|
from django.utils.deprecation import MiddlewareMixin
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -16,6 +17,17 @@ logger = logging.getLogger(__name__)
|
|||||||
_request_context = threading.local()
|
_request_context = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
class ApiCsrfExemptMiddleware(CsrfViewMiddleware):
|
||||||
|
"""Skip CSRF checks for JWT/DRF API routes while keeping CSRF elsewhere."""
|
||||||
|
|
||||||
|
api_prefixes = ("/api/",)
|
||||||
|
|
||||||
|
def process_view(self, request, callback, callback_args, callback_kwargs):
|
||||||
|
if request.path_info.startswith(self.api_prefixes):
|
||||||
|
return None
|
||||||
|
return super().process_view(request, callback, callback_args, callback_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_request_id() -> str | None:
|
def get_request_id() -> str | None:
|
||||||
"""Get current request ID from thread-local storage."""
|
"""Get current request ID from thread-local storage."""
|
||||||
return getattr(_request_context, "request_id", None)
|
return getattr(_request_context, "request_id", None)
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ MIDDLEWARE = [
|
|||||||
"whitenoise.middleware.WhiteNoiseMiddleware",
|
"whitenoise.middleware.WhiteNoiseMiddleware",
|
||||||
"django.contrib.sessions.middleware.SessionMiddleware",
|
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||||
"django.middleware.common.CommonMiddleware",
|
"django.middleware.common.CommonMiddleware",
|
||||||
"django.middleware.csrf.CsrfViewMiddleware",
|
"apps.core.middleware.ApiCsrfExemptMiddleware",
|
||||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||||
"django.contrib.messages.middleware.MessageMiddleware",
|
"django.contrib.messages.middleware.MessageMiddleware",
|
||||||
"django.middleware.clickjacking.XFrameOptionsMiddleware",
|
"django.middleware.clickjacking.XFrameOptionsMiddleware",
|
||||||
|
|||||||
@@ -29,11 +29,6 @@ CORS_ALLOW_CREDENTIALS = True
|
|||||||
CORS_ALLOW_PRIVATE_NETWORK = True
|
CORS_ALLOW_PRIVATE_NETWORK = True
|
||||||
CSRF_COOKIE_SECURE = False
|
CSRF_COOKIE_SECURE = False
|
||||||
SESSION_COOKIE_SECURE = False
|
SESSION_COOKIE_SECURE = False
|
||||||
MIDDLEWARE = [
|
|
||||||
middleware
|
|
||||||
for middleware in MIDDLEWARE
|
|
||||||
if middleware != "django.middleware.csrf.CsrfViewMiddleware"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_local_host(host: str) -> str:
|
def _normalize_local_host(host: str) -> str:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import logging
|
|||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
from apps.core.middleware import (
|
from apps.core.middleware import (
|
||||||
|
ApiCsrfExemptMiddleware,
|
||||||
RequestIDMiddleware,
|
RequestIDMiddleware,
|
||||||
RequestLoggingMiddleware,
|
RequestLoggingMiddleware,
|
||||||
get_request_id,
|
get_request_id,
|
||||||
@@ -73,3 +74,23 @@ class RequestLoggingMiddlewareTest(APITestCase):
|
|||||||
response = middleware.process_response(request, HttpResponse(status=200))
|
response = middleware.process_response(request, HttpResponse(status=200))
|
||||||
self.assertIn("X-Request-ID", response)
|
self.assertIn("X-Request-ID", response)
|
||||||
self.assertIsNone(get_request_id())
|
self.assertIsNone(get_request_id())
|
||||||
|
|
||||||
|
|
||||||
|
class ApiCsrfExemptMiddlewareTest(APITestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.factory = RequestFactory()
|
||||||
|
self.middleware = ApiCsrfExemptMiddleware(lambda req: HttpResponse(status=200))
|
||||||
|
|
||||||
|
def test_api_path_skips_csrf_check(self):
|
||||||
|
request = self.factory.post("/api/v1/users/login/", data={})
|
||||||
|
|
||||||
|
response = self.middleware.process_view(request, lambda req: None, (), {})
|
||||||
|
|
||||||
|
self.assertIsNone(response)
|
||||||
|
|
||||||
|
def test_non_api_path_keeps_csrf_check(self):
|
||||||
|
request = self.factory.post("/admin/login/", data={})
|
||||||
|
|
||||||
|
response = self.middleware.process_view(request, lambda req: None, (), {})
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, 403)
|
||||||
|
|||||||
Reference in New Issue
Block a user