diff --git a/src/apps/core/middleware.py b/src/apps/core/middleware.py index ed88e52..44c4613 100644 --- a/src/apps/core/middleware.py +++ b/src/apps/core/middleware.py @@ -9,6 +9,7 @@ import threading import uuid from django.middleware.csrf import CsrfViewMiddleware +from django.urls import Resolver404, resolve from django.utils.deprecation import MiddlewareMixin logger = logging.getLogger(__name__) @@ -28,6 +29,34 @@ class ApiCsrfExemptMiddleware(CsrfViewMiddleware): return super().process_view(request, callback, callback_args, callback_kwargs) +class ApiSlashlessRouteMiddleware(MiddlewareMixin): + """Route slashless API URLs to existing slash URLs without POST redirects.""" + + api_prefixes = ("/api/",) + + def process_request(self, request): + path_info = request.path_info + if not path_info.startswith(self.api_prefixes) or path_info.endswith("/"): + return None + + try: + resolve(path_info) + return None + except Resolver404: + pass + + slash_path_info = f"{path_info}/" + try: + resolve(slash_path_info) + except Resolver404: + return None + + request.path_info = slash_path_info + request.path = f"{request.path}/" + request.META["PATH_INFO"] = slash_path_info + return None + + def get_request_id() -> str | None: """Get current request ID from thread-local storage.""" return getattr(_request_context, "request_id", None) diff --git a/src/settings/base.py b/src/settings/base.py index ef9e260..42ccf1e 100644 --- a/src/settings/base.py +++ b/src/settings/base.py @@ -166,6 +166,7 @@ MIDDLEWARE = [ "django.middleware.security.SecurityMiddleware", "whitenoise.middleware.WhiteNoiseMiddleware", "django.contrib.sessions.middleware.SessionMiddleware", + "apps.core.middleware.ApiSlashlessRouteMiddleware", "django.middleware.common.CommonMiddleware", "apps.core.middleware.ApiCsrfExemptMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", diff --git a/tests/apps/core/test_middleware.py b/tests/apps/core/test_middleware.py index 421d009..3c2ccf9 100644 --- a/tests/apps/core/test_middleware.py +++ b/tests/apps/core/test_middleware.py @@ -5,6 +5,7 @@ from io import StringIO from apps.core.middleware import ( ApiCsrfExemptMiddleware, + ApiSlashlessRouteMiddleware, RequestIDMiddleware, RequestLoggingMiddleware, get_request_id, @@ -94,3 +95,34 @@ class ApiCsrfExemptMiddlewareTest(APITestCase): response = self.middleware.process_view(request, lambda req: None, (), {}) self.assertEqual(response.status_code, 403) + + +class ApiSlashlessRouteMiddlewareTest(APITestCase): + def setUp(self): + self.factory = RequestFactory() + self.middleware = ApiSlashlessRouteMiddleware( + lambda req: HttpResponse(status=200) + ) + + def test_api_path_without_slash_rewrites_to_existing_slash_route(self): + request = self.factory.post("/api/v1/users/login", data={}) + + response = self.middleware.process_request(request) + + self.assertIsNone(response) + self.assertEqual(request.path_info, "/api/v1/users/login/") + self.assertEqual(request.path, "/api/v1/users/login/") + + def test_unknown_api_path_without_slash_is_unchanged(self): + request = self.factory.post("/api/v1/unknown-route", data={}) + + self.middleware.process_request(request) + + self.assertEqual(request.path_info, "/api/v1/unknown-route") + + def test_non_api_path_without_slash_is_unchanged(self): + request = self.factory.post("/admin/login", data={}) + + self.middleware.process_request(request) + + self.assertEqual(request.path_info, "/admin/login") diff --git a/tests/apps/user/test_views.py b/tests/apps/user/test_views.py index aefcf79..8a115c7 100644 --- a/tests/apps/user/test_views.py +++ b/tests/apps/user/test_views.py @@ -113,6 +113,16 @@ class LoginViewTest(APITestCase): self.assertIn("refresh", response.data) self.assertIn("access", response.data) + def test_login_accepts_slashless_api_url(self): + """Frontend clients can call API URLs without Django's trailing slash.""" + response = self.client.post( + self.login_url.rstrip("/"), self.login_data, format="json" + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("refresh", response.data) + self.assertIn("access", response.data) + def test_login_invalid_credentials(self): """Test login fails with invalid credentials""" data = self.login_data.copy()