"""Tests for exchange API views.""" import json from types import SimpleNamespace from unittest.mock import patch from apps.exchange.models import ExchangeConnection from apps.exchange.services import ExchangePeriodicTaskService, ExchangeServiceError from django.conf import settings from django.urls import reverse from django_celery_beat.models import IntervalSchedule, PeriodicTask from rest_framework import status from rest_framework.test import APITestCase from tests.apps.exchange.factories import ExchangeConnectionFactory from tests.apps.user.factories import UserFactory class ExchangeViewsTest(APITestCase): def setUp(self): self.user = UserFactory.create_user() self.admin = UserFactory.create_superuser() self.connections_url = reverse("api_v1:exchange:connections") self.test_connection_url = reverse("api_v1:exchange:connections-test") self.copy_url = reverse("api_v1:exchange:copy") self.periodic_tasks_url = reverse("api_v1:exchange:periodic-tasks") def test_connections_endpoint_admin_only(self): response = self.client.get(self.connections_url) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.client.force_authenticate(self.user) response = self.client.get(self.connections_url) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.client.force_authenticate(self.admin) response = self.client.get(self.connections_url) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue(response.data["success"]) self.assertIsInstance(response.data["data"], list) @patch("apps.exchange.services.ExchangeConnectionService.validate_target_structure") @patch("apps.exchange.services.ExchangeConnectionService.test_connection") def test_create_connection_success(self, connection_mock, validate_mock): old_active = ExchangeConnectionFactory(is_active=True) payload = { "server": "127.0.0.1", "port": 5432, "username": "postgres", "password": "secret", "database_name": "target_db", "schema_name": "public", } self.client.force_authenticate(self.admin) response = self.client.post(self.connections_url, payload, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(ExchangeConnection.objects.filter(is_active=True).count(), 1) new_connection = ExchangeConnection.objects.get(id=response.data["data"]["id"]) self.assertTrue(new_connection.is_active) self.assertNotEqual(new_connection.password, payload["password"]) self.assertEqual(new_connection.get_decrypted_password(), payload["password"]) old_active.refresh_from_db() self.assertFalse(old_active.is_active) connection_mock.assert_called_once() validate_mock.assert_called_once() @patch("apps.exchange.services.ExchangeConnectionService.test_connection_payload") def test_test_connection_success(self, test_connection_mock): payload = { "server": "127.0.0.1", "port": 5432, "username": "postgres", "password": "secret", "database_name": "target_db", "schema_name": "public", } test_connection_mock.return_value = { "status": "success", "message": "ok", } self.client.force_authenticate(self.admin) response = self.client.post(self.test_connection_url, payload, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data["data"]["status"], "success") self.assertEqual(ExchangeConnection.objects.count(), 0) test_connection_mock.assert_called_once_with(**payload) @patch("apps.exchange.services.ExchangeConnectionService.test_connection_payload") def test_test_connection_failure_returns_400(self, test_connection_mock): payload = { "server": "127.0.0.1", "port": 5432, "username": "postgres", "password": "secret", "database_name": "target_db", "schema_name": "public", } test_connection_mock.side_effect = ExchangeServiceError("Connection refused") self.client.force_authenticate(self.admin) response = self.client.post(self.test_connection_url, payload, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(ExchangeConnection.objects.count(), 0) @patch("apps.exchange.services.ExchangeConnectionService.test_connection") def test_create_connection_fail_rolls_back_active(self, connection_mock): connection_mock.side_effect = ExchangeServiceError("Connection refused") old_active = ExchangeConnectionFactory(is_active=True) payload = { "server": "127.0.0.1", "port": 5432, "username": "postgres", "password": "secret", "database_name": "target_db", "schema_name": "public", } self.client.force_authenticate(self.admin) response = self.client.post(self.connections_url, payload, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(ExchangeConnection.objects.count(), 1) old_active.refresh_from_db() self.assertTrue(old_active.is_active) def test_copy_requires_active_connection(self): self.client.force_authenticate(self.admin) response = self.client.post(self.copy_url, {"mode": "all"}, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) @patch("apps.exchange.views.copy_parsers_data_async.delay") @patch("apps.exchange.services.ExchangeConnectionService.get_active_connection") def test_copy_all_success(self, get_active_mock, delay_mock): active_connection = ExchangeConnectionFactory(is_active=True) get_active_mock.return_value = active_connection delay_mock.return_value = SimpleNamespace(id="task-123") self.client.force_authenticate(self.admin) response = self.client.post(self.copy_url, {"mode": "all"}, format="json") self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED) self.assertEqual(response.data["data"]["status"], "started") self.assertEqual(response.data["data"]["task_id"], "task-123") self.assertEqual(response.data["data"]["connection_id"], active_connection.id) get_active_mock.assert_called_once() delay_mock.assert_called_once_with( connection_id=active_connection.id, payload={"mode": "all", "truncate_before_copy": True}, requested_by_id=self.admin.id, ) def test_copy_single_requires_table(self): self.client.force_authenticate(self.admin) response = self.client.post(self.copy_url, {"mode": "single"}, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertIn("table", str(response.data)) def test_periodic_tasks_endpoint_admin_only(self): response = self.client.get(self.periodic_tasks_url) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.client.force_authenticate(self.user) response = self.client.get(self.periodic_tasks_url) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.client.force_authenticate(self.admin) response = self.client.get(self.periodic_tasks_url) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue(response.data["success"]) self.assertEqual(response.data["data"], []) def test_create_periodic_interval_task_success(self): payload = { "name": "exchange-copy-hourly", "description": "Hourly sync", "enabled": True, "schedule_type": "interval", "interval_every": 1, "interval_period": "hours", "mode": "all", } self.client.force_authenticate(self.admin) response = self.client.post(self.periodic_tasks_url, payload, format="json") self.assertEqual(response.status_code, status.HTTP_201_CREATED) task = PeriodicTask.objects.get(id=response.data["data"]["id"]) self.assertEqual(task.task, ExchangePeriodicTaskService.TASK_NAME) self.assertEqual(response.data["data"]["schedule_type"], "interval") self.assertEqual(response.data["data"]["interval_every"], 1) self.assertEqual(response.data["data"]["interval_period"], "hours") self.assertEqual( json.loads(task.kwargs), { "payload": { "mode": "all", "table": None, "tables": None, "truncate_before_copy": True, } }, ) def test_list_periodic_tasks_returns_only_exchange_tasks(self): interval = IntervalSchedule.objects.create(every=1, period="hours") PeriodicTask.objects.create( name="exchange-copy-hourly", task=ExchangePeriodicTaskService.TASK_NAME, interval=interval, kwargs=json.dumps( {"payload": {"mode": "all", "truncate_before_copy": True}} ), ) PeriodicTask.objects.create( name="another-task", task="apps.parsers.tasks.fake_task", interval=interval, kwargs="{}", ) self.client.force_authenticate(self.admin) response = self.client.get(self.periodic_tasks_url) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(len(response.data["data"]), 1) self.assertEqual(response.data["data"][0]["name"], "exchange-copy-hourly") def test_update_periodic_task_switches_to_crontab(self): interval = IntervalSchedule.objects.create(every=1, period="hours") task = PeriodicTask.objects.create( name="exchange-copy-hourly", task=ExchangePeriodicTaskService.TASK_NAME, description="Hourly sync", enabled=True, interval=interval, kwargs=json.dumps( { "payload": { "mode": "all", "table": None, "tables": None, "truncate_before_copy": True, } } ), ) detail_url = reverse( "api_v1:exchange:periodic-task-detail", kwargs={"task_id": task.id}, ) payload = { "schedule_type": "crontab", "crontab_minute": "0", "crontab_hour": "4", "crontab_day_of_week": "*", "crontab_day_of_month": "*", "crontab_month_of_year": "*", "mode": "single", "table": "parsers_proxy", "enabled": False, } self.client.force_authenticate(self.admin) response = self.client.patch(detail_url, payload, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) task.refresh_from_db() self.assertIsNone(task.interval) self.assertIsNotNone(task.crontab) self.assertEqual(str(task.crontab.timezone), settings.TIME_ZONE) self.assertFalse(task.enabled) self.assertEqual(response.data["data"]["schedule_type"], "crontab") self.assertEqual(response.data["data"]["crontab_hour"], "4") self.assertEqual(response.data["data"]["mode"], "single") self.assertEqual(response.data["data"]["table"], "parsers_proxy") self.assertFalse(IntervalSchedule.objects.filter(id=interval.id).exists()) def test_periodic_task_detail_returns_404_for_non_exchange_task(self): interval = IntervalSchedule.objects.create(every=1, period="hours") task = PeriodicTask.objects.create( name="another-task", task="apps.parsers.tasks.fake_task", interval=interval, kwargs="{}", ) detail_url = reverse( "api_v1:exchange:periodic-task-detail", kwargs={"task_id": task.id}, ) self.client.force_authenticate(self.admin) response = self.client.get(detail_url) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)