"""Tests for core ViewSets""" from __future__ import annotations from typing import Any from apps.core.pagination import StandardPagination from apps.core.viewsets import ( BaseViewSet, BulkMixin, OwnerViewSet, ReadOnlyViewSet, ) from apps.organization.models import Organization from apps.user.models import Profile, User from django.test import TestCase, override_settings from django.urls import include, path from rest_framework import serializers, status, viewsets from rest_framework.decorators import action from rest_framework.permissions import IsAuthenticated from rest_framework.routers import DefaultRouter from rest_framework.test import APITestCase from tests.apps.organization.factories import OrganizationFactory, fake from tests.apps.user.factories import ProfileFactory, UserFactory def _organization_payload() -> dict[str, Any]: organization = OrganizationFactory.build() return { "name": organization.name, "inn": fake.unique.numerify("##########"), "ogrn": organization.ogrn, "kpp": organization.kpp, "okpo": organization.okpo, } class OrganizationSerializer(serializers.ModelSerializer): class Meta: model = Organization fields = ["id", "name", "inn", "ogrn", "kpp", "okpo"] class OrganizationListSerializer(serializers.ModelSerializer): class Meta: model = Organization fields = ["id", "name"] class ProfileSerializer(serializers.ModelSerializer): class Meta: model = Profile fields = ["id", "user", "first_name", "last_name", "bio"] read_only_fields = ["user"] class UserSerializer(serializers.ModelSerializer): class Meta: model = User fields = ["id", "email", "username"] class OrganizationViewSet(BaseViewSet[Organization]): queryset = Organization.objects.all() serializer_class = OrganizationSerializer serializer_classes = {"list": OrganizationListSerializer} only_fields = ["id", "name"] class DeferOrganizationViewSet(BaseViewSet[Organization]): queryset = Organization.objects.all() serializer_class = OrganizationSerializer defer_fields = ["okpo"] class ReadOnlyOrganizationViewSet(ReadOnlyViewSet[Organization]): queryset = Organization.objects.all() serializer_class = OrganizationSerializer class NoPaginationOrganizationViewSet(BaseViewSet[Organization]): queryset = Organization.objects.all() serializer_class = OrganizationSerializer pagination_class = None class ProfileSelectViewSet(BaseViewSet[Profile]): queryset = Profile.objects.all() serializer_class = ProfileSerializer select_related_fields = ["user"] class ProfileOldStyleViewSet(BaseViewSet[Profile]): queryset = Profile.objects.all() serializer_class = ProfileSerializer _select_related = ["user"] class UserPrefetchViewSet(BaseViewSet[User]): queryset = User.objects.all() serializer_class = UserSerializer prefetch_related_fields = ["groups"] class UserOldStyleViewSet(BaseViewSet[User]): queryset = User.objects.all() serializer_class = UserSerializer _prefetch_related = ["groups"] class OwnerProfileViewSet(OwnerViewSet[Profile]): queryset = Profile.objects.all() serializer_class = ProfileSerializer class BulkOrganizationViewSet(BulkMixin, BaseViewSet[Organization]): queryset = Organization.objects.all() serializer_class = OrganizationSerializer bulk_max_items = 2 @action(detail=False, methods=["post"]) def bulk_create(self, request): return super().bulk_create(request) @action(detail=False, methods=["patch"]) def bulk_update(self, request): return super().bulk_update(request) @action(detail=False, methods=["delete"]) def bulk_delete(self, request): return super().bulk_delete(request) router = DefaultRouter() router.register("organizations", OrganizationViewSet, basename="organization") router.register( "organizations-defer", DeferOrganizationViewSet, basename="organization-defer", ) router.register( "organizations-readonly", ReadOnlyOrganizationViewSet, basename="organization-readonly", ) router.register( "organizations-nopage", NoPaginationOrganizationViewSet, basename="organization-nopage", ) router.register("profiles-select", ProfileSelectViewSet, basename="profile-select") router.register("profiles-old", ProfileOldStyleViewSet, basename="profile-old") router.register("users-prefetch", UserPrefetchViewSet, basename="user-prefetch") router.register("users-old", UserOldStyleViewSet, basename="user-old") router.register("profiles-owner", OwnerProfileViewSet, basename="profile-owner") router.register( "bulk-organizations", BulkOrganizationViewSet, basename="bulk-organization", ) urlpatterns = [path("", include(router.urls))] class BaseViewSetTest(TestCase): """Tests for BaseViewSet""" def test_inherits_from_model_viewset(self): """Test BaseViewSet inherits from ModelViewSet""" self.assertTrue(issubclass(BaseViewSet, viewsets.ModelViewSet)) def test_has_pagination_class(self): """Test BaseViewSet has pagination_class""" self.assertEqual(BaseViewSet.pagination_class, StandardPagination) def test_has_permission_classes(self): """Test BaseViewSet has permission_classes""" self.assertIn(IsAuthenticated, BaseViewSet.permission_classes) def test_has_filter_backends(self): """Test BaseViewSet has filter_backends""" self.assertTrue(hasattr(BaseViewSet, "filter_backends")) self.assertIsInstance(BaseViewSet.filter_backends, list) self.assertTrue(len(BaseViewSet.filter_backends) > 0) def test_has_default_ordering(self): """Test BaseViewSet has default ordering""" self.assertEqual(BaseViewSet.ordering, ["-created_at"]) def test_has_serializer_classes_dict(self): """Test BaseViewSet has serializer_classes dict""" self.assertTrue(hasattr(BaseViewSet, "serializer_classes")) self.assertIsInstance(BaseViewSet.serializer_classes, dict) class ReadOnlyViewSetTest(TestCase): """Tests for ReadOnlyViewSet""" def test_inherits_from_readonly_model_viewset(self): """Test ReadOnlyViewSet inherits from ReadOnlyModelViewSet""" self.assertTrue(issubclass(ReadOnlyViewSet, viewsets.ReadOnlyModelViewSet)) def test_has_pagination_class(self): """Test ReadOnlyViewSet has pagination_class""" self.assertEqual(ReadOnlyViewSet.pagination_class, StandardPagination) def test_has_filter_backends(self): """Test ReadOnlyViewSet has filter_backends""" self.assertTrue(hasattr(ReadOnlyViewSet, "filter_backends")) self.assertTrue(len(ReadOnlyViewSet.filter_backends) > 0) class OwnerViewSetTest(TestCase): """Tests for OwnerViewSet""" def test_inherits_from_base_viewset(self): """Test OwnerViewSet inherits from BaseViewSet""" self.assertTrue(issubclass(OwnerViewSet, BaseViewSet)) def test_has_owner_field(self): """Test OwnerViewSet has owner_field attribute""" self.assertEqual(OwnerViewSet.owner_field, "user") class BulkMixinTest(TestCase): """Tests for BulkMixin""" def test_has_bulk_create_method(self): """Test BulkMixin has bulk_create method""" self.assertTrue(hasattr(BulkMixin, "bulk_create")) self.assertTrue(callable(BulkMixin.bulk_create)) def test_has_bulk_delete_method(self): """Test BulkMixin has bulk_delete method""" self.assertTrue(hasattr(BulkMixin, "bulk_delete")) self.assertTrue(callable(BulkMixin.bulk_delete)) @override_settings(ROOT_URLCONF=__name__) class BaseViewSetIntegrationTest(APITestCase): def setUp(self): self.user = UserFactory.create_user() self.client.force_authenticate(self.user) def test_list_paginated_uses_list_serializer(self): OrganizationFactory.create_batch(3) response = self.client.get("/organizations/?page=1&page_size=2") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue(response.data["success"]) self.assertEqual(len(response.data["data"]), 2) self.assertIn("pagination", response.data["meta"]) self.assertSetEqual(set(response.data["data"][0].keys()), {"id", "name"}) def test_list_without_pagination(self): OrganizationFactory.create_batch(2) response = self.client.get("/organizations-nopage/") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue(response.data["success"]) self.assertEqual(len(response.data["data"]), 2) self.assertIsNone(response.data["meta"]) def test_retrieve_uses_default_serializer(self): organization = OrganizationFactory() response = self.client.get(f"/organizations/{organization.pk}/") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn("inn", response.data["data"]) def test_create_update_delete(self): payload = _organization_payload() created = self.client.post("/organizations/", payload, format="json") self.assertEqual(created.status_code, status.HTTP_201_CREATED) organization_id = created.data["data"]["id"] new_name = fake.company() updated = self.client.patch( f"/organizations/{organization_id}/", {"name": new_name}, format="json", ) self.assertEqual(updated.status_code, status.HTTP_200_OK) self.assertEqual(updated.data["data"]["name"], new_name) deleted = self.client.delete(f"/organizations/{organization_id}/") self.assertEqual(deleted.status_code, status.HTTP_204_NO_CONTENT) @override_settings(ROOT_URLCONF=__name__) class ReadOnlyViewSetIntegrationTest(APITestCase): def setUp(self): self.user = UserFactory.create_user() self.client.force_authenticate(self.user) def test_readonly_list_and_retrieve(self): organization = OrganizationFactory() OrganizationFactory.create_batch(2) list_response = self.client.get("/organizations-readonly/") self.assertEqual(list_response.status_code, status.HTTP_200_OK) self.assertTrue(list_response.data["success"]) detail_response = self.client.get(f"/organizations-readonly/{organization.pk}/") self.assertEqual(detail_response.status_code, status.HTTP_200_OK) self.assertEqual(detail_response.data["data"]["id"], str(organization.pk)) @override_settings(ROOT_URLCONF=__name__) class OwnerViewSetIntegrationTest(APITestCase): def setUp(self): self.user = UserFactory.create_user() self.other_user = UserFactory.create_user() self.client.force_authenticate(self.user) def test_list_filters_by_owner(self): ProfileFactory.create_profile(user=self.user) ProfileFactory.create_profile(user=self.other_user) response = self.client.get("/profiles-owner/") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["data"]), 1) def test_create_sets_owner(self): user = UserFactory.create_user() user.profile.delete() self.client.force_authenticate(user) response = self.client.post( "/profiles-owner/", {"first_name": fake.first_name(), "last_name": fake.last_name()}, format="json", ) self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.data["data"]["user"], user.id) @override_settings(ROOT_URLCONF=__name__) class BulkMixinIntegrationTest(APITestCase): def setUp(self): self.user = UserFactory.create_user() self.client.force_authenticate(self.user) def test_bulk_create_empty_items(self): response = self.client.post( "/bulk-organizations/bulk_create/", {}, format="json" ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertFalse(response.data["success"]) def test_bulk_create_too_many(self): items = [_organization_payload() for _ in range(3)] response = self.client.post( "/bulk-organizations/bulk_create/", {"items": items}, format="json" ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.data["errors"][0]["code"], "too_many_items") def test_bulk_create_update_delete(self): items = [_organization_payload(), _organization_payload()] created = self.client.post( "/bulk-organizations/bulk_create/", {"items": items}, format="json" ) self.assertEqual(created.status_code, status.HTTP_201_CREATED) created_ids = [item["id"] for item in created.data["data"]] update_items = [ {"id": created_ids[0], "name": fake.company()}, { "id": fake.uuid4(), "name": fake.company(), }, ] updated = self.client.patch( "/bulk-organizations/bulk_update/", {"items": update_items}, format="json" ) self.assertEqual(updated.status_code, status.HTTP_200_OK) self.assertEqual(len(updated.data["data"]["updated"]), 1) self.assertEqual(len(updated.data["data"]["errors"]), 1) deleted = self.client.delete( "/bulk-organizations/bulk_delete/", {"ids": created_ids}, format="json" ) self.assertEqual(deleted.status_code, status.HTTP_200_OK) self.assertEqual(deleted.data["data"]["deleted"], len(created_ids)) def test_bulk_update_missing_ids(self): response = self.client.patch( "/bulk-organizations/bulk_update/", {"items": [{"name": fake.company()}]}, format="json", ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.data["errors"][0]["code"], "missing_ids") @override_settings(ROOT_URLCONF=__name__) class QuerysetOptimizationIntegrationTest(APITestCase): def setUp(self): self.user = UserFactory.create_user() self.client.force_authenticate(self.user) def test_select_related_and_old_style(self): ProfileFactory.create_profile(user=self.user) response_new = self.client.get("/profiles-select/") self.assertEqual(response_new.status_code, status.HTTP_200_OK) response_old = self.client.get("/profiles-old/") self.assertEqual(response_old.status_code, status.HTTP_200_OK) def test_prefetch_related_and_old_style(self): UserFactory.create_user() response_new = self.client.get("/users-prefetch/") self.assertEqual(response_new.status_code, status.HTTP_200_OK) response_old = self.client.get("/users-old/") self.assertEqual(response_old.status_code, status.HTTP_200_OK) def test_defer_fields_branch(self): OrganizationFactory.create_batch(2) response = self.client.get("/organizations-defer/") self.assertEqual(response.status_code, status.HTTP_200_OK)