"""Tests for core ViewSets"""

from __future__ import annotations

from typing import Any

from apps.core.pagination import StandardPagination
from apps.core.viewsets import (
    BaseViewSet,
    BulkMixin,
    OwnerViewSet,
    ReadOnlyViewSet,
)
from apps.parsers.models import Proxy
from apps.user.models import Profile, User
from django.test import TestCase, override_settings
from django.urls import include, path
from rest_framework import serializers, status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
from rest_framework.routers import DefaultRouter
from rest_framework.test import APITestCase
from tests.apps.parsers.factories import ProxyFactory, fake
from tests.apps.user.factories import ProfileFactory, UserFactory


def _proxy_payload() -> dict[str, Any]:
    """Build a request payload from an unsaved ``ProxyFactory`` instance."""
    proxy = ProxyFactory.build()
    return {
        "address": proxy.address,
        "is_active": proxy.is_active,
        "fail_count": proxy.fail_count,
        "description": proxy.description,
    }


def _response_ids(response) -> set[Any]:
    """Return the ``id`` of every item in ``response.data["data"]``."""
    return {item["id"] for item in response.data["data"]}


class ProxySerializer(serializers.ModelSerializer):
    """Full Proxy representation used for detail/create/update actions."""

    class Meta:
        model = Proxy
        fields = ["id", "address", "is_active", "fail_count", "description"]


class ProxyListSerializer(serializers.ModelSerializer):
    """Reduced Proxy representation used for the ``list`` action."""

    class Meta:
        model = Proxy
        fields = ["id", "address"]


class ProfileSerializer(serializers.ModelSerializer):
    """Profile representation; ``user`` is read-only so the view sets it."""

    class Meta:
        model = Profile
        fields = ["id", "user", "first_name", "last_name", "bio"]
        read_only_fields = ["user"]


class UserSerializer(serializers.ModelSerializer):
    """Minimal User representation."""

    class Meta:
        model = User
        fields = ["id", "email", "username"]


class ProxyViewSet(BaseViewSet[Proxy]):
    """Exercises per-action serializers and ``only_fields``."""

    queryset = Proxy.objects.all()
    serializer_class = ProxySerializer
    serializer_classes = {"list": ProxyListSerializer}
    only_fields = ["id", "address"]


class DeferProxyViewSet(BaseViewSet[Proxy]):
    """Exercises the ``defer_fields`` queryset branch."""

    queryset = Proxy.objects.all()
    serializer_class = ProxySerializer
    defer_fields = ["description"]


class ReadOnlyProxyViewSet(ReadOnlyViewSet[Proxy]):
    """Read-only Proxy endpoint."""

    queryset = Proxy.objects.all()
    serializer_class = ProxySerializer


class NoPaginationProxyViewSet(BaseViewSet[Proxy]):
    """Proxy endpoint with pagination disabled."""

    queryset = Proxy.objects.all()
    serializer_class = ProxySerializer
    pagination_class = None


class ProfileSelectViewSet(BaseViewSet[Profile]):
    """Uses the ``select_related_fields`` attribute."""

    queryset = Profile.objects.all()
    serializer_class = ProfileSerializer
    select_related_fields = ["user"]


class ProfileOldStyleViewSet(BaseViewSet[Profile]):
    """Uses the legacy ``_select_related`` attribute."""

    queryset = Profile.objects.all()
    serializer_class = ProfileSerializer
    _select_related = ["user"]


class UserPrefetchViewSet(BaseViewSet[User]):
    """Uses the ``prefetch_related_fields`` attribute."""

    queryset = User.objects.all()
    serializer_class = UserSerializer
    prefetch_related_fields = ["groups"]


class UserOldStyleViewSet(BaseViewSet[User]):
    """Uses the legacy ``_prefetch_related`` attribute."""

    queryset = User.objects.all()
    serializer_class = UserSerializer
    _prefetch_related = ["groups"]


class OwnerProfileViewSet(OwnerViewSet[Profile]):
    """Owner-scoped Profile endpoint."""

    queryset = Profile.objects.all()
    serializer_class = ProfileSerializer


class BulkProxyViewSet(BulkMixin, BaseViewSet[Proxy]):
    """Proxy endpoint exposing the BulkMixin operations (max 2 items)."""

    queryset = Proxy.objects.all()
    serializer_class = ProxySerializer
    bulk_max_items = 2

    # The methods are re-declared with @action so DefaultRouter generates
    # routes for them; each simply delegates to the BulkMixin implementation.
    @action(detail=False, methods=["post"])
    def bulk_create(self, request):
        return super().bulk_create(request)

    @action(detail=False, methods=["patch"])
    def bulk_update(self, request):
        return super().bulk_update(request)

    @action(detail=False, methods=["delete"])
    def bulk_delete(self, request):
        return super().bulk_delete(request)


router = DefaultRouter()
router.register("proxies", ProxyViewSet, basename="proxy")
router.register("proxies-defer", DeferProxyViewSet, basename="proxy-defer")
router.register("proxies-readonly", ReadOnlyProxyViewSet, basename="proxy-readonly")
router.register("proxies-nopage", NoPaginationProxyViewSet, basename="proxy-nopage")
router.register("profiles-select", ProfileSelectViewSet, basename="profile-select")
router.register("profiles-old", ProfileOldStyleViewSet, basename="profile-old")
router.register("users-prefetch", UserPrefetchViewSet, basename="user-prefetch")
router.register("users-old", UserOldStyleViewSet, basename="user-old")
router.register("profiles-owner", OwnerProfileViewSet, basename="profile-owner")
router.register("bulk-proxies", BulkProxyViewSet, basename="bulk-proxy")

# Picked up by the integration tests via override_settings(ROOT_URLCONF=__name__).
urlpatterns = [path("", include(router.urls))]


class BaseViewSetTest(TestCase):
    """Tests for BaseViewSet"""

    def test_inherits_from_model_viewset(self):
        """Test BaseViewSet inherits from ModelViewSet"""
        self.assertTrue(issubclass(BaseViewSet, viewsets.ModelViewSet))

    def test_has_pagination_class(self):
        """Test BaseViewSet has pagination_class"""
        self.assertEqual(BaseViewSet.pagination_class, StandardPagination)

    def test_has_permission_classes(self):
        """Test BaseViewSet has permission_classes"""
        self.assertIn(IsAuthenticated, BaseViewSet.permission_classes)

    def test_has_filter_backends(self):
        """Test BaseViewSet has filter_backends"""
        self.assertTrue(hasattr(BaseViewSet, "filter_backends"))
        self.assertIsInstance(BaseViewSet.filter_backends, list)
        self.assertGreater(len(BaseViewSet.filter_backends), 0)

    def test_has_default_ordering(self):
        """Test BaseViewSet has default ordering"""
        self.assertEqual(BaseViewSet.ordering, ["-created_at"])

    def test_has_serializer_classes_dict(self):
        """Test BaseViewSet has serializer_classes dict"""
        self.assertTrue(hasattr(BaseViewSet, "serializer_classes"))
        self.assertIsInstance(BaseViewSet.serializer_classes, dict)


class ReadOnlyViewSetTest(TestCase):
    """Tests for ReadOnlyViewSet"""

    def test_inherits_from_readonly_model_viewset(self):
        """Test ReadOnlyViewSet inherits from ReadOnlyModelViewSet"""
        self.assertTrue(issubclass(ReadOnlyViewSet, viewsets.ReadOnlyModelViewSet))

    def test_has_pagination_class(self):
        """Test ReadOnlyViewSet has pagination_class"""
        self.assertEqual(ReadOnlyViewSet.pagination_class, StandardPagination)

    def test_has_filter_backends(self):
        """Test ReadOnlyViewSet has filter_backends"""
        self.assertTrue(hasattr(ReadOnlyViewSet, "filter_backends"))
        self.assertGreater(len(ReadOnlyViewSet.filter_backends), 0)


class OwnerViewSetTest(TestCase):
    """Tests for OwnerViewSet"""

    def test_inherits_from_base_viewset(self):
        """Test OwnerViewSet inherits from BaseViewSet"""
        self.assertTrue(issubclass(OwnerViewSet, BaseViewSet))

    def test_has_owner_field(self):
        """Test OwnerViewSet has owner_field attribute"""
        self.assertEqual(OwnerViewSet.owner_field, "user")


class BulkMixinTest(TestCase):
    """Tests for BulkMixin"""

    def test_has_bulk_create_method(self):
        """Test BulkMixin has bulk_create method"""
        self.assertTrue(hasattr(BulkMixin, "bulk_create"))
        self.assertTrue(callable(BulkMixin.bulk_create))

    def test_has_bulk_update_method(self):
        """Test BulkMixin has bulk_update method"""
        self.assertTrue(hasattr(BulkMixin, "bulk_update"))
        self.assertTrue(callable(BulkMixin.bulk_update))

    def test_has_bulk_delete_method(self):
        """Test BulkMixin has bulk_delete method"""
        self.assertTrue(hasattr(BulkMixin, "bulk_delete"))
        self.assertTrue(callable(BulkMixin.bulk_delete))


@override_settings(ROOT_URLCONF=__name__)
class BaseViewSetIntegrationTest(APITestCase):
    """End-to-end CRUD, serializer selection and pagination for BaseViewSet."""

    def setUp(self):
        self.user = UserFactory.create_user()
        self.client.force_authenticate(self.user)

    def test_list_paginated_uses_list_serializer(self):
        ProxyFactory.create_batch(3)
        response = self.client.get("/proxies/?page=1&page_size=2")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(response.data["success"])
        self.assertEqual(len(response.data["data"]), 2)
        self.assertIn("pagination", response.data["meta"])
        # The "list" action must use ProxyListSerializer's reduced field set.
        self.assertSetEqual(
            set(response.data["data"][0].keys()), {"id", "address"}
        )

    def test_list_without_pagination(self):
        ProxyFactory.create_batch(2)
        response = self.client.get("/proxies-nopage/")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(response.data["success"])
        self.assertEqual(len(response.data["data"]), 2)
        self.assertIsNone(response.data["meta"])

    def test_retrieve_uses_default_serializer(self):
        proxy = ProxyFactory()
        response = self.client.get(f"/proxies/{proxy.pk}/")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertIn("fail_count", response.data["data"])

    def test_create_update_delete(self):
        payload = _proxy_payload()
        created = self.client.post("/proxies/", payload, format="json")
        self.assertEqual(created.status_code, status.HTTP_201_CREATED)
        proxy_id = created.data["data"]["id"]

        new_description = fake.sentence(nb_words=3)
        updated = self.client.patch(
            f"/proxies/{proxy_id}/",
            {"description": new_description},
            format="json",
        )
        self.assertEqual(updated.status_code, status.HTTP_200_OK)
        self.assertEqual(updated.data["data"]["description"], new_description)

        deleted = self.client.delete(f"/proxies/{proxy_id}/")
        self.assertEqual(deleted.status_code, status.HTTP_204_NO_CONTENT)


@override_settings(ROOT_URLCONF=__name__)
class ReadOnlyViewSetIntegrationTest(APITestCase):
    """List/retrieve through a ReadOnlyViewSet endpoint."""

    def setUp(self):
        self.user = UserFactory.create_user()
        self.client.force_authenticate(self.user)

    def test_readonly_list_and_retrieve(self):
        proxy = ProxyFactory()
        ProxyFactory.create_batch(2)

        list_response = self.client.get("/proxies-readonly/")
        self.assertEqual(list_response.status_code, status.HTTP_200_OK)
        self.assertTrue(list_response.data["success"])

        detail_response = self.client.get(f"/proxies-readonly/{proxy.pk}/")
        self.assertEqual(detail_response.status_code, status.HTTP_200_OK)
        self.assertEqual(detail_response.data["data"]["id"], proxy.pk)


@override_settings(ROOT_URLCONF=__name__)
class OwnerViewSetIntegrationTest(APITestCase):
    """Owner scoping on list and owner assignment on create."""

    def setUp(self):
        self.user = UserFactory.create_user()
        self.other_user = UserFactory.create_user()
        self.client.force_authenticate(self.user)

    def test_list_filters_by_owner(self):
        ProfileFactory.create_profile(user=self.user)
        ProfileFactory.create_profile(user=self.other_user)
        response = self.client.get("/profiles-owner/")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data["data"]), 1)
        # A count of 1 alone would also pass if the wrong owner's row leaked.
        self.assertEqual(response.data["data"][0]["user"], self.user.id)

    def test_create_sets_owner(self):
        user = UserFactory.create_user()
        # Remove the user's existing profile so a new one can be created.
        user.profile.delete()
        self.client.force_authenticate(user)
        response = self.client.post(
            "/profiles-owner/",
            {"first_name": fake.first_name(), "last_name": fake.last_name()},
            format="json",
        )
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(response.data["data"]["user"], user.id)


@override_settings(ROOT_URLCONF=__name__)
class BulkMixinIntegrationTest(APITestCase):
    """Bulk create/update/delete endpoints and their validation errors."""

    def setUp(self):
        self.user = UserFactory.create_user()
        self.client.force_authenticate(self.user)

    def test_bulk_create_empty_items(self):
        response = self.client.post("/bulk-proxies/bulk_create/", {}, format="json")
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertFalse(response.data["success"])

    def test_bulk_create_too_many(self):
        # One more item than BulkProxyViewSet.bulk_max_items (2).
        items = [_proxy_payload() for _ in range(3)]
        response = self.client.post(
            "/bulk-proxies/bulk_create/", {"items": items}, format="json"
        )
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.data["errors"][0]["code"], "too_many_items")

    def test_bulk_create_update_delete(self):
        items = [_proxy_payload(), _proxy_payload()]
        created = self.client.post(
            "/bulk-proxies/bulk_create/", {"items": items}, format="json"
        )
        self.assertEqual(created.status_code, status.HTTP_201_CREATED)
        created_ids = [item["id"] for item in created.data["data"]]

        # One existing id and one id that should not exist: expect a partial
        # success with one entry in "updated" and one in "errors".
        update_items = [
            {"id": created_ids[0], "description": fake.sentence(nb_words=2)},
            {
                "id": fake.random_int(min=999999, max=9999999),
                "description": fake.word(),
            },
        ]
        updated = self.client.patch(
            "/bulk-proxies/bulk_update/", {"items": update_items}, format="json"
        )
        self.assertEqual(updated.status_code, status.HTTP_200_OK)
        self.assertEqual(len(updated.data["data"]["updated"]), 1)
        self.assertEqual(len(updated.data["data"]["errors"]), 1)

        deleted = self.client.delete(
            "/bulk-proxies/bulk_delete/", {"ids": created_ids}, format="json"
        )
        self.assertEqual(deleted.status_code, status.HTTP_200_OK)
        self.assertEqual(deleted.data["data"]["deleted"], len(created_ids))

    def test_bulk_update_missing_ids(self):
        response = self.client.patch(
            "/bulk-proxies/bulk_update/",
            {"items": [{"address": fake.word()}]},
            format="json",
        )
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(response.data["errors"][0]["code"], "missing_ids")


@override_settings(ROOT_URLCONF=__name__)
class QuerysetOptimizationIntegrationTest(APITestCase):
    """select_related/prefetch_related/defer configuration branches."""

    def setUp(self):
        self.user = UserFactory.create_user()
        self.client.force_authenticate(self.user)

    def test_select_related_and_old_style(self):
        ProfileFactory.create_profile(user=self.user)
        response_new = self.client.get("/profiles-select/")
        self.assertEqual(response_new.status_code, status.HTTP_200_OK)
        response_old = self.client.get("/profiles-old/")
        self.assertEqual(response_old.status_code, status.HTTP_200_OK)
        # The legacy attribute must be an exact alias: same rows returned.
        self.assertTrue(_response_ids(response_new))
        self.assertSetEqual(_response_ids(response_new), _response_ids(response_old))

    def test_prefetch_related_and_old_style(self):
        UserFactory.create_user()
        response_new = self.client.get("/users-prefetch/")
        self.assertEqual(response_new.status_code, status.HTTP_200_OK)
        response_old = self.client.get("/users-old/")
        self.assertEqual(response_old.status_code, status.HTTP_200_OK)
        # The legacy attribute must be an exact alias: same rows returned.
        self.assertTrue(_response_ids(response_new))
        self.assertSetEqual(_response_ids(response_new), _response_ids(response_old))

    def test_defer_fields_branch(self):
        ProxyFactory.create_batch(2)
        response = self.client.get("/proxies-defer/")
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data["data"]), 2)
        # Deferring a column must not drop it from the serialized output;
        # Django loads deferred fields on access.
        for item in response.data["data"]:
            self.assertIn("description", item)