445 lines
16 KiB
Python
445 lines
16 KiB
Python
"""Tests for core views: health checks, API versioning URLs, and background jobs."""
|
|
|
|
import sys
|
|
import types
|
|
from datetime import timedelta
|
|
|
|
from apps.core import views as core_views
|
|
from apps.core.views import HealthCheckView
|
|
from django.urls import reverse
|
|
from django.utils import timezone
|
|
from rest_framework import status
|
|
from rest_framework.test import APIRequestFactory, APITestCase
|
|
|
|
from tests.apps.user.factories import UserFactory
|
|
from tests.utils.fixtures import fake
|
|
|
|
|
|
class HealthCheckViewTest(APITestCase):
    """End-to-end tests for the aggregated health-check endpoint."""

    def _fetch_health(self, params=None):
        """Issue a GET to the ``core:health`` route with optional query params."""
        return self.client.get(reverse("core:health"), params)

    def test_health_check_url_reverse(self):
        """The named route ``core:health`` resolves to ``/health/``."""
        self.assertEqual(reverse("core:health"), "/health/")

    def test_health_check_success(self):
        """A healthy service answers 200 with status, version and checks."""
        response = self._fetch_health()
        payload = response.data

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        for required_key in ("status", "version", "checks"):
            self.assertIn(required_key, payload)
        self.assertIn("database", payload["checks"])

    def test_health_check_database_up(self):
        """The database probe reports ``up`` together with a latency figure."""
        database_check = self._fetch_health().data["checks"]["database"]

        self.assertEqual(database_check["status"], "up")
        self.assertIn("latency_ms", database_check)

    def test_health_check_includes_celery_when_requested(self):
        """``include_celery=true`` adds a Celery entry whose status is up/down."""
        checks = self._fetch_health({"include_celery": "true"}).data["checks"]

        self.assertIn("celery", checks)
        self.assertIn(checks["celery"]["status"], ["up", "down"])

    def test_health_check_redis_present(self):
        """Redis is always reported, with status ``up``, ``down`` or ``skipped``."""
        checks = self._fetch_health().data["checks"]

        self.assertIn("redis", checks)
        self.assertIn(checks["redis"]["status"], ["up", "down", "skipped"])
|
|
|
|
|
|
class HealthCheckStatusCombinationsTest(APITestCase):
    """Map combinations of stubbed probe results to the overall verdict."""

    @staticmethod
    def _dispatch(view_class, params=None):
        """Run ``view_class`` against a synthetic GET request to ``/health/``."""
        request = APIRequestFactory().get("/health/", params)
        return view_class.as_view()(request)

    def test_unhealthy_when_database_down(self):
        """A failing database makes the service unhealthy (HTTP 503)."""

        class DownDbHealthCheck(HealthCheckView):
            def _check_database(self):  # type: ignore[override]
                return {"status": "down", "error": "db"}

            def _check_redis(self):  # type: ignore[override]
                return {"status": "up", "latency_ms": 1}

        response = self._dispatch(DownDbHealthCheck)

        self.assertEqual(response.status_code, status.HTTP_503_SERVICE_UNAVAILABLE)
        self.assertEqual(response.data["status"], "unhealthy")

    def test_degraded_when_redis_down_and_celery_down(self):
        """Healthy DB with Redis and Celery down is degraded but still 200."""

        class DegradedHealthCheck(HealthCheckView):
            def _check_database(self):  # type: ignore[override]
                return {"status": "up", "latency_ms": 1}

            def _check_redis(self):  # type: ignore[override]
                return {"status": "down", "error": "redis"}

            def _check_celery(self):  # type: ignore[override]
                return {"status": "down", "error": "celery"}

        response = self._dispatch(DegradedHealthCheck, {"include_celery": "true"})

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["status"], "degraded")

    def test_degraded_when_celery_down_only(self):
        """Only Celery being down is enough to report degraded (still 200)."""

        class CeleryDownHealthCheck(HealthCheckView):
            def _check_database(self):  # type: ignore[override]
                return {"status": "up", "latency_ms": 1}

            def _check_redis(self):  # type: ignore[override]
                return {"status": "up", "latency_ms": 1}

            def _check_celery(self):  # type: ignore[override]
                return {"status": "down", "error": "celery"}

        response = self._dispatch(CeleryDownHealthCheck, {"include_celery": "true"})

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["status"], "degraded")
|
|
|
|
|
|
class HealthCheckInternalTests(APITestCase):
    """Unit tests for the individual probe helpers of ``HealthCheckView``.

    Each test replaces one collaborator with a stub: the DB connection, the
    ``django_redis`` module, or the Celery app. The original is restored in a
    ``finally`` clause, so it comes back even if the probe raises.
    """

    # Sentinel that distinguishes "key absent from sys.modules" from a key
    # that is present with the value ``None``.
    _ABSENT = object()

    def _check_redis_with(self, replacement):
        """Run ``_check_redis`` with ``sys.modules["django_redis"]`` replaced.

        Args:
            replacement: Object to install as the ``django_redis`` module.
                ``None`` makes importing it raise ``ImportError``.

        Returns:
            The dict returned by ``HealthCheckView()._check_redis()``.
        """
        original = sys.modules.get("django_redis", self._ABSENT)
        sys.modules["django_redis"] = replacement
        try:
            return HealthCheckView()._check_redis()
        finally:
            if original is self._ABSENT:
                sys.modules.pop("django_redis", None)
            else:
                sys.modules["django_redis"] = original

    def _check_celery_with(self, fake_app):
        """Run ``_check_celery`` with ``core.celery.app`` swapped for ``fake_app``.

        Returns:
            The dict returned by ``HealthCheckView()._check_celery()``.
        """
        from core import celery as celery_module

        original_app = celery_module.app
        celery_module.app = fake_app
        try:
            return HealthCheckView()._check_celery()
        finally:
            celery_module.app = original_app

    def test_check_database_returns_down_on_error(self):
        """A cursor that fails on entry makes the database probe report down."""

        class _BrokenCursor:
            def __enter__(self):
                raise RuntimeError("db down")

            def __exit__(self, exc_type, exc, tb):
                return False

        class _BrokenConnection:
            def cursor(self):
                return _BrokenCursor()

        original_connection = core_views.connection
        try:
            core_views.connection = _BrokenConnection()
            result = HealthCheckView()._check_database()
        finally:
            core_views.connection = original_connection

        self.assertEqual(result["status"], "down")

    def test_check_redis_import_error(self):
        """If ``django_redis`` cannot be imported, the Redis probe is skipped."""
        result = self._check_redis_with(None)

        self.assertEqual(result["status"], "skipped")

    def test_check_redis_success(self):
        """A Redis client whose ``ping`` succeeds makes the probe report up."""

        class _FakeRedis:
            def ping(self):
                return True

        # Mirror django_redis's real ``get_redis_connection(alias="default",
        # write=True)`` signature. The stub then works however the view passes
        # its arguments: positionally, by keyword, or not at all.
        def _fake_get_redis_connection(alias="default", write=True):
            return _FakeRedis()

        fake_module = types.SimpleNamespace(
            get_redis_connection=_fake_get_redis_connection
        )
        result = self._check_redis_with(fake_module)

        self.assertEqual(result["status"], "up")

    def test_check_celery_up(self):
        """Celery is up when ``inspect().active()`` returns worker data."""

        class _FakeInspector:
            def active(self):
                return {"worker": []}

        class _FakeControl:
            def inspect(self, timeout=None):
                return _FakeInspector()

        class _FakeApp:
            control = _FakeControl()

        result = self._check_celery_with(_FakeApp())

        self.assertEqual(result["status"], "up")

    def test_check_celery_error(self):
        """An exception from ``control.inspect`` makes the Celery probe down."""

        class _BrokenControl:
            def inspect(self, timeout=None):
                raise RuntimeError("boom")

        class _BrokenApp:
            control = _BrokenControl()

        result = self._check_celery_with(_BrokenApp())

        self.assertEqual(result["status"], "down")
|
|
|
|
|
|
class LivenessViewTest(APITestCase):
    """Tests for the liveness probe endpoint."""

    route_name = "core:liveness"

    def test_liveness_url_reverse(self):
        """The liveness route resolves to ``/health/live/``."""
        self.assertEqual(reverse(self.route_name), "/health/live/")

    def test_liveness_returns_alive(self):
        """The liveness probe always answers 200 with status ``alive``."""
        response = self.client.get(reverse(self.route_name))

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["status"], "alive")
|
|
|
|
|
|
class ReadinessViewTest(APITestCase):
    """Tests for the readiness probe endpoint."""

    def test_readiness_url_reverse(self):
        """The readiness route resolves to ``/health/ready/``."""
        self.assertEqual(reverse("core:readiness"), "/health/ready/")

    def test_readiness_returns_ready(self):
        """With a working database the readiness probe answers ``ready``."""
        response = self.client.get(reverse("core:readiness"))

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["status"], "ready")

    def test_readiness_returns_not_ready_on_db_error(self):
        """A failing DB cursor turns the readiness probe into 503 not_ready."""

        class _BrokenCursor:
            def __enter__(self):
                raise RuntimeError("db down")

            def __exit__(self, exc_type, exc, tb):
                return False

        class _BrokenConnection:
            def cursor(self):
                return _BrokenCursor()

        readiness_url = reverse("core:readiness")
        saved_connection = core_views.connection
        try:
            core_views.connection = _BrokenConnection()
            response = self.client.get(readiness_url)
        finally:
            # Always put the real connection back for subsequent tests.
            core_views.connection = saved_connection

        self.assertEqual(response.status_code, status.HTTP_503_SERVICE_UNAVAILABLE)
        self.assertEqual(response.data["status"], "not_ready")
|
|
|
|
|
|
class APIVersioningURLTest(APITestCase):
    """Verify that user routes live under the versioned ``/api/v1/`` prefix."""

    def _assert_route(self, route_name, expected_path):
        """Assert that ``route_name`` reverses to exactly ``expected_path``."""
        self.assertEqual(reverse(route_name), expected_path)

    def test_api_v1_user_register_reverse(self):
        """Registration lives at ``/api/v1/users/register/``."""
        self._assert_route("api_v1:user:register", "/api/v1/users/register/")

    def test_api_v1_user_login_reverse(self):
        """Login lives at ``/api/v1/users/login/``."""
        self._assert_route("api_v1:user:login", "/api/v1/users/login/")

    def test_api_v1_user_logout_reverse(self):
        """Logout lives at ``/api/v1/users/logout/``."""
        self._assert_route("api_v1:user:logout", "/api/v1/users/logout/")

    def test_api_v1_user_current_user_reverse(self):
        """The current-user endpoint lives at ``/api/v1/users/me/``."""
        self._assert_route("api_v1:user:current_user", "/api/v1/users/me/")

    def test_api_v1_user_token_refresh_reverse(self):
        """Token refresh lives at ``/api/v1/users/token/refresh/``."""
        self._assert_route(
            "api_v1:user:token_refresh", "/api/v1/users/token/refresh/"
        )

    def test_api_v1_user_password_change_reverse(self):
        """Password change lives at ``/api/v1/users/password/change/``."""
        self._assert_route(
            "api_v1:user:password_change", "/api/v1/users/password/change/"
        )
|
|
|
|
|
|
class BackgroundJobsViewTest(APITestCase):
    """Tests for the background-job status, list and stream endpoints."""

    def setUp(self):
        # A job owner, an unrelated regular user, and a superuser.
        self.user = UserFactory.create_user()
        self.other = UserFactory.create_user()
        self.admin = UserFactory.create_superuser()

    def _create_job(self, *, task_id: str, user_id: int | None, status: str):
        """Persist a ``BackgroundJob`` that started now and ended 5 seconds later.

        Args:
            task_id: Task identifier. It is also the ``task_id`` URL kwarg of
                the per-job endpoints.
            user_id: Primary key of the owning user, or ``None`` for an
                unowned job.
            status: Raw ``BackgroundJob.status`` value, e.g. ``"success"``.
                Inside this method only, this name shadows the module-level
                DRF ``status`` import.

        Returns:
            The created ``BackgroundJob`` instance.
        """
        from apps.core.models import BackgroundJob

        started_at = timezone.now()
        completed_at = started_at + timedelta(seconds=5)
        return BackgroundJob.objects.create(
            task_id=task_id,
            task_name="apps.test.task",
            status=status,
            user_id=user_id,
            started_at=started_at,
            completed_at=completed_at,
        )

    def test_job_status_for_owner(self):
        """The owner receives the full status payload for their own job."""
        job = self._create_job(
            task_id="job-owner", user_id=self.user.id, status="success"
        )
        self.client.force_authenticate(self.user)
        url = reverse("api_v1:jobs:job-status", kwargs={"task_id": job.task_id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(
            set(response.data.keys()),
            {"task_id", "status", "progress", "message", "result", "error"},
        )
        self.assertEqual(response.data["task_id"], job.task_id)
        self.assertEqual(response.data["status"], "success")

    def test_job_status_forbidden_for_other_user(self):
        """Another regular user cannot see someone else's job."""
        job = self._create_job(
            task_id="job-forbidden", user_id=self.user.id, status="success"
        )
        self.client.force_authenticate(self.other)
        url = reverse("api_v1:jobs:job-status", kwargs={"task_id": job.task_id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_job_status_for_admin(self):
        """A superuser can read any user's job status."""
        job = self._create_job(
            task_id="job-admin", user_id=self.user.id, status="success"
        )
        self.client.force_authenticate(self.admin)
        url = reverse("api_v1:jobs:job-status", kwargs={"task_id": job.task_id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Make sure the admin actually received this job, not just any 200.
        self.assertEqual(response.data["task_id"], job.task_id)

    def test_job_status_forbidden_for_unowned_job(self):
        """Regular users cannot read jobs that have no owner."""
        job = self._create_job(task_id="job-unowned", user_id=None, status="success")
        self.client.force_authenticate(self.other)
        url = reverse("api_v1:jobs:job-status", kwargs={"task_id": job.task_id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_job_list_filters_status(self):
        """The ``status`` filter keeps only matching jobs, with a compact shape."""
        self._create_job(task_id="job-1", user_id=self.user.id, status="success")
        self._create_job(task_id="job-2", user_id=self.user.id, status="started")
        self.client.force_authenticate(self.user)
        url = reverse("api_v1:jobs:job-list")
        response = self.client.get(url, {"status": "success", "limit": 10})
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data["results"]), 1)
        self.assertEqual(
            set(response.data["results"][0].keys()),
            {"task_id", "status", "progress", "name"},
        )
        # ``name`` is the last dotted segment of task_name "apps.test.task".
        self.assertEqual(response.data["results"][0]["name"], "task")

    def test_job_list_limit(self):
        """``limit`` caps the listing at exactly that many jobs."""
        for idx in range(5):
            self._create_job(
                task_id=f"job-{idx}-{fake.random_int(min=1, max=9999)}",
                user_id=self.user.id,
                status="success",
            )
        self.client.force_authenticate(self.user)
        url = reverse("api_v1:jobs:job-list")
        response = self.client.get(url, {"limit": 2})
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        # Five matching jobs exist, so the limit must be hit exactly. A ``<=``
        # check would also pass for an empty (broken) listing.
        self.assertEqual(len(response.data["results"]), 2)

    def test_job_list_invalid_limit_returns_400(self):
        """A non-integer ``limit`` is rejected with 400."""
        self._create_job(
            task_id="job-invalid-limit", user_id=self.user.id, status="success"
        )
        self.client.force_authenticate(self.user)
        url = reverse("api_v1:jobs:job-list")
        response = self.client.get(url, {"limit": "abc"})
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_job_status_maps_started_to_running(self):
        """The stored ``started`` status is exposed as ``running``."""
        job = self._create_job(
            task_id="job-running",
            user_id=self.user.id,
            status="started",
        )
        self.client.force_authenticate(self.user)
        url = reverse("api_v1:jobs:job-status", kwargs={"task_id": job.task_id})

        response = self.client.get(url)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data["status"], "running")

    def test_job_stream_returns_completed_sse_event(self):
        """Streaming a finished job emits a ``completed`` SSE event."""
        job = self._create_job(
            task_id="job-stream-complete",
            user_id=self.user.id,
            status="success",
        )
        self.client.force_authenticate(self.user)
        url = reverse("api_v1:jobs:job-stream", kwargs={"task_id": job.task_id})

        response = self.client.get(url)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        chunks = b"".join(response.streaming_content).decode("utf-8")
        self.assertIn("event: completed", chunks)
        self.assertIn('"task_id": "job-stream-complete"', chunks)

    def test_job_stream_forbidden_for_other_user(self):
        """Another regular user cannot stream someone else's job."""
        job = self._create_job(
            task_id="job-stream-forbidden",
            user_id=self.user.id,
            status="success",
        )
        self.client.force_authenticate(self.other)
        url = reverse("api_v1:jobs:job-stream", kwargs={"task_id": job.task_id})

        response = self.client.get(url)

        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|