fix(api): accept slashless api urls
This commit is contained in:
@@ -9,6 +9,7 @@ import threading
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from django.middleware.csrf import CsrfViewMiddleware
|
from django.middleware.csrf import CsrfViewMiddleware
|
||||||
|
from django.urls import Resolver404, resolve
|
||||||
from django.utils.deprecation import MiddlewareMixin
|
from django.utils.deprecation import MiddlewareMixin
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -28,6 +29,34 @@ class ApiCsrfExemptMiddleware(CsrfViewMiddleware):
|
|||||||
return super().process_view(request, callback, callback_args, callback_kwargs)
|
return super().process_view(request, callback, callback_args, callback_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ApiSlashlessRouteMiddleware(MiddlewareMixin):
|
||||||
|
"""Route slashless API URLs to existing slash URLs without POST redirects."""
|
||||||
|
|
||||||
|
api_prefixes = ("/api/",)
|
||||||
|
|
||||||
|
def process_request(self, request):
|
||||||
|
path_info = request.path_info
|
||||||
|
if not path_info.startswith(self.api_prefixes) or path_info.endswith("/"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
resolve(path_info)
|
||||||
|
return None
|
||||||
|
except Resolver404:
|
||||||
|
pass
|
||||||
|
|
||||||
|
slash_path_info = f"{path_info}/"
|
||||||
|
try:
|
||||||
|
resolve(slash_path_info)
|
||||||
|
except Resolver404:
|
||||||
|
return None
|
||||||
|
|
||||||
|
request.path_info = slash_path_info
|
||||||
|
request.path = f"{request.path}/"
|
||||||
|
request.META["PATH_INFO"] = slash_path_info
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_request_id() -> str | None:
|
def get_request_id() -> str | None:
|
||||||
"""Get current request ID from thread-local storage."""
|
"""Get current request ID from thread-local storage."""
|
||||||
return getattr(_request_context, "request_id", None)
|
return getattr(_request_context, "request_id", None)
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ MIDDLEWARE = [
|
|||||||
"django.middleware.security.SecurityMiddleware",
|
"django.middleware.security.SecurityMiddleware",
|
||||||
"whitenoise.middleware.WhiteNoiseMiddleware",
|
"whitenoise.middleware.WhiteNoiseMiddleware",
|
||||||
"django.contrib.sessions.middleware.SessionMiddleware",
|
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||||
|
"apps.core.middleware.ApiSlashlessRouteMiddleware",
|
||||||
"django.middleware.common.CommonMiddleware",
|
"django.middleware.common.CommonMiddleware",
|
||||||
"apps.core.middleware.ApiCsrfExemptMiddleware",
|
"apps.core.middleware.ApiCsrfExemptMiddleware",
|
||||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from io import StringIO
|
|||||||
|
|
||||||
from apps.core.middleware import (
|
from apps.core.middleware import (
|
||||||
ApiCsrfExemptMiddleware,
|
ApiCsrfExemptMiddleware,
|
||||||
|
ApiSlashlessRouteMiddleware,
|
||||||
RequestIDMiddleware,
|
RequestIDMiddleware,
|
||||||
RequestLoggingMiddleware,
|
RequestLoggingMiddleware,
|
||||||
get_request_id,
|
get_request_id,
|
||||||
@@ -94,3 +95,34 @@ class ApiCsrfExemptMiddlewareTest(APITestCase):
|
|||||||
response = self.middleware.process_view(request, lambda req: None, (), {})
|
response = self.middleware.process_view(request, lambda req: None, (), {})
|
||||||
|
|
||||||
self.assertEqual(response.status_code, 403)
|
self.assertEqual(response.status_code, 403)
|
||||||
|
|
||||||
|
|
||||||
|
class ApiSlashlessRouteMiddlewareTest(APITestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.factory = RequestFactory()
|
||||||
|
self.middleware = ApiSlashlessRouteMiddleware(
|
||||||
|
lambda req: HttpResponse(status=200)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_api_path_without_slash_rewrites_to_existing_slash_route(self):
|
||||||
|
request = self.factory.post("/api/v1/users/login", data={})
|
||||||
|
|
||||||
|
response = self.middleware.process_request(request)
|
||||||
|
|
||||||
|
self.assertIsNone(response)
|
||||||
|
self.assertEqual(request.path_info, "/api/v1/users/login/")
|
||||||
|
self.assertEqual(request.path, "/api/v1/users/login/")
|
||||||
|
|
||||||
|
def test_unknown_api_path_without_slash_is_unchanged(self):
|
||||||
|
request = self.factory.post("/api/v1/unknown-route", data={})
|
||||||
|
|
||||||
|
self.middleware.process_request(request)
|
||||||
|
|
||||||
|
self.assertEqual(request.path_info, "/api/v1/unknown-route")
|
||||||
|
|
||||||
|
def test_non_api_path_without_slash_is_unchanged(self):
|
||||||
|
request = self.factory.post("/admin/login", data={})
|
||||||
|
|
||||||
|
self.middleware.process_request(request)
|
||||||
|
|
||||||
|
self.assertEqual(request.path_info, "/admin/login")
|
||||||
|
|||||||
@@ -113,6 +113,16 @@ class LoginViewTest(APITestCase):
|
|||||||
self.assertIn("refresh", response.data)
|
self.assertIn("refresh", response.data)
|
||||||
self.assertIn("access", response.data)
|
self.assertIn("access", response.data)
|
||||||
|
|
||||||
|
def test_login_accepts_slashless_api_url(self):
|
||||||
|
"""Frontend clients can call API URLs without Django's trailing slash."""
|
||||||
|
response = self.client.post(
|
||||||
|
self.login_url.rstrip("/"), self.login_data, format="json"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertIn("refresh", response.data)
|
||||||
|
self.assertIn("access", response.data)
|
||||||
|
|
||||||
def test_login_invalid_credentials(self):
|
def test_login_invalid_credentials(self):
|
||||||
"""Test login fails with invalid credentials"""
|
"""Test login fails with invalid credentials"""
|
||||||
data = self.login_data.copy()
|
data = self.login_data.copy()
|
||||||
|
|||||||
Reference in New Issue
Block a user