Some checks failed
CI/CD Pipeline / Code Quality Checks (pull_request) Failing after 2m39s
CI/CD Pipeline / Run Tests (pull_request) Successful in 3m0s
CI/CD Pipeline / Run API Inventory E2E Tests (pull_request) Successful in 35s
CI/CD Pipeline / Telegram Notify Success (pull_request) Has been skipped
330 lines
13 KiB
Python
330 lines
13 KiB
Python
"""Tests for exchange API views."""
|
|
|
|
import json
|
|
from types import SimpleNamespace
|
|
from unittest.mock import patch
|
|
|
|
from apps.exchange.models import ExchangeConnection
|
|
from apps.exchange.services import ExchangePeriodicTaskService, ExchangeServiceError
|
|
from django.conf import settings
|
|
from django.urls import reverse
|
|
from django_celery_beat.models import IntervalSchedule, PeriodicTask
|
|
from rest_framework import status
|
|
from rest_framework.test import APITestCase
|
|
|
|
from tests.apps.exchange.factories import ExchangeConnectionFactory
|
|
from tests.apps.user.factories import UserFactory
|
|
|
|
|
|
class ExchangeViewsTest(APITestCase):
    """API tests for the exchange connection, copy and periodic-task views."""

    def setUp(self):
        """Create a regular user and a superuser, then resolve endpoint URLs."""
        self.user = UserFactory.create_user()
        self.admin = UserFactory.create_superuser()
        self.connections_url = reverse("api_v1:exchange:connections")
        self.test_connection_url = reverse("api_v1:exchange:connections-test")
        self.copy_url = reverse("api_v1:exchange:copy")
        self.periodic_tasks_url = reverse("api_v1:exchange:periodic-tasks")

    # ------------------------------------------------------------------
    # Private helpers shared by the tests below.
    # ------------------------------------------------------------------

    @staticmethod
    def _connection_payload():
        """Return a fresh connection payload so tests never share one dict."""
        return {
            "server": "127.0.0.1",
            "port": 5432,
            "username": "postgres",
            "password": "secret",
            "database_name": "target_db",
            "schema_name": "public",
        }

    def _post_as_admin(self, url, body):
        """Authenticate the client as the superuser and POST ``body`` as JSON."""
        self.client.force_authenticate(self.admin)
        return self.client.post(url, body, format="json")

    def _get_as_each_role(self, url):
        """GET ``url`` anonymously, as a regular user, then as the superuser.

        Asserts 401, 403 and 200 respectively and returns the superuser's
        response so the caller can inspect its body.
        """
        anonymous_response = self.client.get(url)
        self.assertEqual(
            anonymous_response.status_code, status.HTTP_401_UNAUTHORIZED
        )

        self.client.force_authenticate(self.user)
        user_response = self.client.get(url)
        self.assertEqual(user_response.status_code, status.HTTP_403_FORBIDDEN)

        self.client.force_authenticate(self.admin)
        admin_response = self.client.get(url)
        self.assertEqual(admin_response.status_code, status.HTTP_200_OK)
        return admin_response

    @staticmethod
    def _periodic_task_detail_url(periodic_task):
        """Build the detail URL for the given ``PeriodicTask`` row."""
        return reverse(
            "api_v1:exchange:periodic-task-detail",
            kwargs={"task_id": periodic_task.id},
        )

    # ------------------------------------------------------------------
    # Connections
    # ------------------------------------------------------------------

    def test_connections_endpoint_admin_only(self):
        """Listing connections is 401 anonymously, 403 for users, 200 for admin."""
        admin_response = self._get_as_each_role(self.connections_url)
        self.assertIsInstance(admin_response.data["results"], list)

    @patch("apps.exchange.services.ExchangeConnectionService.validate_target_structure")
    @patch("apps.exchange.services.ExchangeConnectionService.prepare_target_structure")
    @patch("apps.exchange.services.ExchangeConnectionService.test_connection")
    def test_create_connection_success(
        self,
        connection_mock,
        prepare_mock,
        validate_mock,
    ):
        """A successful create returns 201 and swaps the active connection.

        The response must expose exactly the non-secret fields, the stored
        password must differ from the submitted one while decrypting back to
        it, and each patched service step must have run exactly once.
        """
        previous_active = ExchangeConnectionFactory(is_active=True)
        body = self._connection_payload()

        response = self._post_as_admin(self.connections_url, body)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        active_total = ExchangeConnection.objects.filter(is_active=True).count()
        self.assertEqual(active_total, 1)
        expected_fields = {
            "id",
            "server",
            "port",
            "username",
            "database_name",
            "schema_name",
            "is_active",
        }
        self.assertEqual(set(response.data.keys()), expected_fields)

        created = ExchangeConnection.objects.get(id=response.data["id"])
        self.assertTrue(created.is_active)
        self.assertNotEqual(created.password, body["password"])
        self.assertEqual(created.get_decrypted_password(), body["password"])

        previous_active.refresh_from_db()
        self.assertFalse(previous_active.is_active)

        # Mocks are checked in the same order as the individual assertions
        # this loop replaces: connect, prepare, validate.
        for service_mock in (connection_mock, prepare_mock, validate_mock):
            service_mock.assert_called_once()

    @patch("apps.exchange.services.ExchangeConnectionService.test_connection_payload")
    def test_test_connection_success(self, test_connection_mock):
        """The connection-test endpoint returns 200 and persists nothing."""
        body = self._connection_payload()
        test_connection_mock.return_value = {"status": "success", "message": "ok"}

        response = self._post_as_admin(self.test_connection_url, body)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertTrue(response.data["success"])
        self.assertEqual(response.data["message"], "ok")
        self.assertEqual(ExchangeConnection.objects.count(), 0)
        test_connection_mock.assert_called_once_with(**body)

    @patch("apps.exchange.services.ExchangeConnectionService.test_connection_payload")
    def test_test_connection_failure_returns_400(self, test_connection_mock):
        """An ``ExchangeServiceError`` from the service surfaces as a 400."""
        body = self._connection_payload()
        test_connection_mock.side_effect = ExchangeServiceError("Connection refused")

        response = self._post_as_admin(self.test_connection_url, body)

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(ExchangeConnection.objects.count(), 0)

    @patch("apps.exchange.services.ExchangeConnectionService.test_connection")
    def test_create_connection_fail_rolls_back_active(self, connection_mock):
        """A failed create leaves the previously active connection untouched."""
        connection_mock.side_effect = ExchangeServiceError("Connection refused")
        previous_active = ExchangeConnectionFactory(is_active=True)
        body = self._connection_payload()

        response = self._post_as_admin(self.connections_url, body)

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(ExchangeConnection.objects.count(), 1)

        previous_active.refresh_from_db()
        self.assertTrue(previous_active.is_active)

    # ------------------------------------------------------------------
    # Copy
    # ------------------------------------------------------------------

    def test_copy_requires_active_connection(self):
        """Requesting a copy with no connection rows present yields a 400."""
        response = self._post_as_admin(self.copy_url, {"mode": "all"})

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    @patch("apps.exchange.views.copy_parsers_data_async.delay")
    @patch("apps.exchange.services.ExchangeConnectionService.get_active_connection")
    def test_copy_all_success(self, get_active_mock, delay_mock):
        """A full copy is accepted (202) and dispatched via the patched ``delay``."""
        connection = ExchangeConnectionFactory(is_active=True)
        get_active_mock.return_value = connection
        delay_mock.return_value = SimpleNamespace(id="task-123")

        response = self._post_as_admin(self.copy_url, {"mode": "all"})

        self.assertEqual(response.status_code, status.HTTP_202_ACCEPTED)
        result = response.data["data"]
        self.assertEqual(result["status"], "started")
        self.assertEqual(result["task_id"], "task-123")
        self.assertEqual(result["connection_id"], connection.id)
        get_active_mock.assert_called_once()
        delay_mock.assert_called_once_with(
            connection_id=connection.id,
            payload={"mode": "all", "truncate_before_copy": True},
            requested_by_id=self.admin.id,
        )

    def test_copy_single_requires_table(self):
        """Mode ``single`` without a table is a 400 whose body mentions ``table``."""
        response = self._post_as_admin(self.copy_url, {"mode": "single"})

        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertIn("table", str(response.data))

    # ------------------------------------------------------------------
    # Periodic tasks
    # ------------------------------------------------------------------

    def test_periodic_tasks_endpoint_admin_only(self):
        """Listing periodic tasks is admin-only and empty when none exist."""
        admin_response = self._get_as_each_role(self.periodic_tasks_url)
        self.assertEqual(admin_response.data["results"], [])

    def test_create_periodic_interval_task_success(self):
        """Creating an interval schedule stores an exchange task with its kwargs."""
        body = {
            "schedule_type": "interval",
            "interval_every": 1,
            "interval_period": "hours",
            "notify_on_error": True,
        }

        response = self._post_as_admin(self.periodic_tasks_url, body)

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        stored_task = PeriodicTask.objects.get(id=response.data["id"])
        self.assertEqual(stored_task.task, ExchangePeriodicTaskService.TASK_NAME)

        expected_fields = {
            "id",
            "schedule_type",
            "interval_every",
            "interval_period",
            "notify_on_error",
        }
        self.assertEqual(set(response.data.keys()), expected_fields)
        self.assertEqual(response.data["schedule_type"], "interval")
        self.assertEqual(response.data["interval_every"], 1)
        self.assertEqual(response.data["interval_period"], "hours")
        self.assertTrue(response.data["notify_on_error"])

        expected_kwargs = {
            "payload": {
                "mode": "all",
                "table": None,
                "tables": None,
                "truncate_before_copy": True,
                "notify_on_error": True,
            }
        }
        self.assertEqual(json.loads(stored_task.kwargs), expected_kwargs)

    def test_list_periodic_tasks_returns_only_exchange_tasks(self):
        """Only rows whose task name is the exchange task name are listed."""
        hourly = IntervalSchedule.objects.create(every=1, period="hours")
        exchange_kwargs = json.dumps(
            {"payload": {"mode": "all", "truncate_before_copy": True}}
        )
        PeriodicTask.objects.create(
            name="exchange-copy-hourly",
            task=ExchangePeriodicTaskService.TASK_NAME,
            interval=hourly,
            kwargs=exchange_kwargs,
        )
        # A row with a different task name that the listing should leave out.
        PeriodicTask.objects.create(
            name="another-task",
            task="apps.parsers.tasks.fake_task",
            interval=hourly,
            kwargs="{}",
        )

        self.client.force_authenticate(self.admin)
        response = self.client.get(self.periodic_tasks_url)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        listed = response.data["results"]
        self.assertEqual(len(listed), 1)
        self.assertEqual(listed[0]["schedule_type"], "interval")

    def test_update_periodic_task_switches_to_crontab(self):
        """PATCHing to a daily schedule replaces the interval with a crontab.

        The interval row created for the task must no longer exist afterwards,
        and the crontab's timezone must equal ``settings.TIME_ZONE``.
        """
        hourly = IntervalSchedule.objects.create(every=1, period="hours")
        initial_kwargs = json.dumps(
            {
                "payload": {
                    "mode": "all",
                    "table": None,
                    "tables": None,
                    "truncate_before_copy": True,
                }
            }
        )
        task = PeriodicTask.objects.create(
            name="exchange-copy-hourly",
            task=ExchangePeriodicTaskService.TASK_NAME,
            description="Hourly sync",
            enabled=True,
            interval=hourly,
            kwargs=initial_kwargs,
        )
        detail_url = self._periodic_task_detail_url(task)
        body = {
            "schedule_type": "daily",
            "crontab_minute": 0,
            "crontab_hour": 4,
            "notify_on_error": True,
        }

        self.client.force_authenticate(self.admin)
        response = self.client.patch(detail_url, body, format="json")

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        task.refresh_from_db()
        self.assertIsNone(task.interval)
        self.assertIsNotNone(task.crontab)
        self.assertEqual(str(task.crontab.timezone), settings.TIME_ZONE)
        self.assertTrue(task.enabled)
        self.assertEqual(response.data["schedule_type"], "daily")
        self.assertEqual(response.data["crontab_hour"], 4)
        self.assertTrue(response.data["notify_on_error"])
        self.assertFalse(IntervalSchedule.objects.filter(id=hourly.id).exists())

    def test_periodic_task_detail_returns_404_for_non_exchange_task(self):
        """The detail endpoint answers 404 for a task with another task name."""
        hourly = IntervalSchedule.objects.create(every=1, period="hours")
        foreign_task = PeriodicTask.objects.create(
            name="another-task",
            task="apps.parsers.tasks.fake_task",
            interval=hourly,
            kwargs="{}",
        )
        detail_url = self._periodic_task_detail_url(foreign_task)

        self.client.force_authenticate(self.admin)
        response = self.client.get(detail_url)

        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|