"""Tests for core Celery tasks""" import logging from io import StringIO from apps.core.tasks import ( BaseTask, IdempotentTask, PeriodicTask, TimedTask, TransactionalTask, ) from celery import Task from core.celery import app as celery_app from django.test import TestCase @celery_app.task(base=BaseTask, bind=True) def base_task(self, marker: str): return marker @celery_app.task(base=TransactionalTask, bind=True) def transactional_task(self, marker: str): from apps.core.models import BackgroundJob BackgroundJob.objects.create(task_id=marker, task_name="test.tx") if marker == "fail": raise ValueError("boom") return marker @celery_app.task(base=IdempotentTask, bind=True) def idempotent_task(self, marker: str): from apps.core.models import BackgroundJob BackgroundJob.objects.create(task_id=marker, task_name="test.idem") return marker @celery_app.task(base=TimedTask, bind=True) def timed_task(self, marker: str): return marker class BaseTaskTest(TestCase): """Tests for BaseTask""" def test_inherits_from_celery_task(self): """Test BaseTask inherits from Celery Task""" self.assertTrue(issubclass(BaseTask, Task)) def test_has_default_retry_settings(self): """Test BaseTask has default retry settings""" self.assertEqual(BaseTask.max_retries, 3) self.assertTrue(BaseTask.retry_backoff) self.assertEqual(BaseTask.retry_backoff_max, 600) def test_acks_late_enabled(self): """Test acks_late is enabled""" self.assertTrue(BaseTask.acks_late) def test_reject_on_worker_lost(self): """Test reject_on_worker_lost is enabled""" self.assertTrue(BaseTask.reject_on_worker_lost) class TransactionalTaskTest(TestCase): """Tests for TransactionalTask""" def test_inherits_from_base_task(self): """Test TransactionalTask inherits from BaseTask""" self.assertTrue(issubclass(TransactionalTask, BaseTask)) class IdempotentTaskTest(TestCase): """Tests for IdempotentTask""" def test_inherits_from_base_task(self): """Test IdempotentTask inherits from BaseTask""" self.assertTrue(issubclass(IdempotentTask, BaseTask)) def test_has_lock_timeout(self): """Test IdempotentTask has lock_timeout attribute""" self.assertEqual(IdempotentTask.lock_timeout, 3600) class TimedTaskTest(TestCase): """Tests for TimedTask""" def test_inherits_from_base_task(self): """Test TimedTask inherits from BaseTask""" self.assertTrue(issubclass(TimedTask, BaseTask)) def test_has_slow_threshold(self): """Test TimedTask has slow_threshold attribute""" self.assertEqual(TimedTask.slow_threshold, 60) class PeriodicTaskTest(TestCase): """Tests for PeriodicTask""" def test_inherits_from_timed_task(self): """Test PeriodicTask inherits from TimedTask""" self.assertTrue(issubclass(PeriodicTask, TimedTask)) def test_max_retries_is_one(self): """Test max_retries is 1 for periodic tasks""" self.assertEqual(PeriodicTask.max_retries, 1) def test_autoretry_for_is_empty(self): """Test autoretry_for is empty for periodic tasks""" self.assertEqual(PeriodicTask.autoretry_for, ()) class TaskRuntimeBehaviorTest(TestCase): def setUp(self): self.logger = logging.getLogger("apps.core.tasks") self.logger.setLevel(logging.INFO) self.stream = StringIO() handler = logging.StreamHandler(self.stream) self.logger.handlers = [] self.logger.addHandler(handler) def test_base_task_hooks(self): base_task.apply(args=("ok",)).get() base_task.request_stack.push(type("Req", (), {"retries": 1})()) base_task.on_retry(Exception("retry"), "id-1", (), {}, None) base_task.request_stack.pop() output = self.stream.getvalue() self.assertIn("base_task", output) def test_transactional_task_rolls_back(self): from apps.core.models import BackgroundJob with self.assertRaises(ValueError): transactional_task.apply(args=("fail",)).get() self.assertFalse(BackgroundJob.objects.filter(task_id="fail").exists()) def test_transactional_task_commits(self): from apps.core.models import BackgroundJob transactional_task.apply(args=("ok",)).get() self.assertTrue(BackgroundJob.objects.filter(task_id="ok").exists()) def test_idempotent_task_skips_second_call(self): from apps.core.models import BackgroundJob idempotent_task.apply(args=("idem",)).get() idempotent_task.apply(args=("idem",)).get() self.assertEqual(BackgroundJob.objects.filter(task_id="idem").count(), 1) def test_timed_task_logs_warning(self): timed_task.slow_threshold = 0 result = timed_task.apply(args=("payload",)).get() self.assertEqual(result, "payload") output = self.stream.getvalue() self.assertIn("timed_task", output)