diff --git a/src/organizations/api_enrichment.py b/src/organizations/api_enrichment.py index 33eb90b..e75542e 100644 --- a/src/organizations/api_enrichment.py +++ b/src/organizations/api_enrichment.py @@ -21,6 +21,7 @@ from apps.parsers.models import ( from django.db.models import Count, Prefetch, Q from registers.models import RegistryMembershipPeriod +from organizations.data_sources import to_api_data_source, to_internal_data_source from organizations.models import Organization GENERIC_SOURCES = ( @@ -45,12 +46,7 @@ DATA_PRESENCE_KEYS = ( ParserLoadLog.Source.FNS_REPORTS, ) DATA_PRESENCE_KEY_SET = {str(source) for source in DATA_PRESENCE_KEYS} -API_DATA_SOURCE_ALIASES = { - ParserLoadLog.Source.TRUDVSEM: "vacancies", -} -API_DATA_SOURCE_KEY_SET = { - API_DATA_SOURCE_ALIASES.get(source, str(source)) for source in DATA_PRESENCE_KEYS -} +API_DATA_SOURCE_KEY_SET = {to_api_data_source(source) for source in DATA_PRESENCE_KEYS} @dataclass(frozen=True) @@ -98,19 +94,6 @@ def data_presence_identity_values(source: str) -> tuple[set[str], set[str]]: return matches["inn"], matches["ogrn"] -def to_api_data_source(source: str) -> str: - """Return v2 public data source key for an internal parser source.""" - return API_DATA_SOURCE_ALIASES.get(source, str(source)) - - -def to_internal_data_source(source: str) -> str: - """Return internal parser source key from a v2 public key.""" - for internal_source, api_source in API_DATA_SOURCE_ALIASES.items(): - if source == api_source: - return str(internal_source) - return source - - def _source_matches(source: str) -> dict[str, set[str]]: if source == ParserLoadLog.Source.INDUSTRIAL: return OrganizationApiEnrichmentService._matching_identifiers_for_all( diff --git a/src/organizations/data_sources.py b/src/organizations/data_sources.py new file mode 100644 index 0000000..f3e3183 --- /dev/null +++ b/src/organizations/data_sources.py @@ -0,0 +1,48 @@ +"""Helpers for organization API data source keys and summaries.""" + +from __future__ 
import annotations + +from typing import Any + +API_DATA_SOURCE_ALIASES = { + "trudvsem": "vacancies", +} + + +def to_api_data_source(source: str) -> str: + """Return v2 public data source key for an internal parser source.""" + return API_DATA_SOURCE_ALIASES.get(str(source), str(source)) + + +def to_internal_data_source(source: str) -> str: + """Return internal parser source key from a v2 public key.""" + for internal_source, api_source in API_DATA_SOURCE_ALIASES.items(): + if source == api_source: + return internal_source + return source + + +def snapshot_data_with_api_keys(data: dict[str, Any]) -> dict[str, Any]: + """Return snapshot data keyed by public API source names.""" + return {to_api_data_source(source): value for source, value in data.items()} + + +def data_source_summary(data: dict[str, Any]) -> list[dict[str, int | str]]: + """Return non-empty source counters for a data payload.""" + summary: list[dict[str, int | str]] = [] + for source in sorted(data): + value = data[source] + if isinstance(value, list): + count = len(value) + elif value: + count = 1 + else: + count = 0 + if count: + summary.append({"source": source, "count": count}) + return summary + + +def snapshot_data_source_summary(data: dict[str, Any]) -> list[dict[str, int | str]]: + """Return non-empty public source counters for stored snapshot data.""" + return data_source_summary(snapshot_data_with_api_keys(data)) diff --git a/src/organizations/filters.py b/src/organizations/filters.py index 9059d38..2758221 100644 --- a/src/organizations/filters.py +++ b/src/organizations/filters.py @@ -1,6 +1,7 @@ """Filters for organizations API v2.""" -from django.db.models import CharField, Exists, OuterRef, Q +from apps.parsers.models import FinancialReport, ParserLoadLog +from django.db.models import CharField, Q from django.db.models.functions import Cast from django_filters import rest_framework as filters from registers.models import RegistryMembershipPeriod @@ -84,6 +85,9 @@ class 
OrganizationFilter(filters.FilterSet): if source not in DATA_PRESENCE_KEYS: return queryset.none() + if source == ParserLoadLog.Source.FNS_REPORTS: + return self._filter_by_fns_report_presence(queryset, value) + inn_values, ogrn_values = data_presence_identity_values(source) filtered = self._filter_by_registry_identities( queryset, inn_values, ogrn_values @@ -106,6 +110,17 @@ class OrganizationFilter(filters.FilterSet): query |= Q(ogrn__in=ogrn_values) | Q(ogrip__in=ogrn_values) return queryset.filter(query) + @staticmethod + def _filter_by_fns_report_presence(queryset, value): + report_ogrns = FinancialReport.objects.order_by().values_list( + "ogrn", + flat=True, + ).exclude(ogrn__isnull=True) + query = Q(ogrn__in=report_ogrns) | Q(ogrip__in=report_ogrns) + if value: + return queryset.filter(query) + return queryset.exclude(query) + @classmethod def _filter_by_registry_membership( cls, @@ -115,23 +130,44 @@ class OrganizationFilter(filters.FilterSet): registry_name: str | None = None, has_registry: bool = True, ): - membership = cls._registry_membership_subquery( + query = cls._registry_identity_query( registry_id=registry_id, registry_name=registry_name, ) - return queryset.annotate(_has_registry=Exists(membership)).filter( - _has_registry=has_registry + if has_registry: + return queryset.filter(query) + return queryset.exclude(query) + + @classmethod + def _registry_identity_query( + cls, + *, + registry_id: str | None = None, + registry_name: str | None = None, + ): + inn_values, ogrn_values = cls._registry_identity_value_querysets( + registry_id=registry_id, + registry_name=registry_name, + ) + return ( + Q(inn__in=inn_values) | Q(ogrn__in=ogrn_values) | Q(ogrip__in=ogrn_values) ) @staticmethod - def _registry_membership_subquery( + def _registry_identity_value_querysets( *, registry_id: str | None = None, registry_name: str | None = None, ): membership = RegistryMembershipPeriod.objects.filter( ended_at__isnull=True, - ).annotate( + ).order_by() + if registry_id: + membership = 
membership.filter(registry_id=registry_id) + if registry_name: + membership = membership.filter(registry__name__icontains=registry_name) + + membership = membership.annotate( organization_inn_text=Cast( "organization__mn_inn", output_field=CharField() ), @@ -139,13 +175,8 @@ class OrganizationFilter(filters.FilterSet): "organization__mn_ogrn", output_field=CharField() ), ) - if registry_id: - membership = membership.filter(registry_id=registry_id) - if registry_name: - membership = membership.filter(registry__name__icontains=registry_name) - return membership.filter( - Q(organization_inn_text=OuterRef("inn")) - | Q(organization_ogrn_text=OuterRef("ogrn")) - | Q(organization_ogrn_text=OuterRef("ogrip")) + return ( + membership.filter(organization__mn_inn__isnull=False).values_list("organization_inn_text", flat=True), + membership.filter(organization__mn_ogrn__isnull=False).values_list("organization_ogrn_text", flat=True), ) diff --git a/src/organizations/migrations/0005_snapshot_data_source_counts.py b/src/organizations/migrations/0005_snapshot_data_source_counts.py new file mode 100644 index 0000000..0815593 --- /dev/null +++ b/src/organizations/migrations/0005_snapshot_data_source_counts.py @@ -0,0 +1,119 @@ +from django.db import migrations, models + + +API_DATA_SOURCE_ALIASES = { + "trudvsem": "vacancies", +} + + +def to_api_data_source(source): + return API_DATA_SOURCE_ALIASES.get(str(source), str(source)) + + +def data_source_summary(data): + summary = [] + for source in sorted(data, key=to_api_data_source): + value = data[source] + if isinstance(value, list): + count = len(value) + elif value: + count = 1 + else: + count = 0 + if count: + summary.append({"source": to_api_data_source(source), "count": count}) + return summary + + +def backfill_data_source_counts_python(apps): + snapshot_model = apps.get_model("organizations", "OrganizationDataSnapshot") + updates = [] + + for snapshot in snapshot_model.objects.only( + "organization_id", + "data", + "data_source_counts", + ).iterator(chunk_size=100): + snapshot.data_source_counts = data_source_summary(snapshot.data) + 
updates.append(snapshot) + if len(updates) >= 100: + snapshot_model.objects.bulk_update(updates, ["data_source_counts"]) + updates = [] + + if updates: + snapshot_model.objects.bulk_update(updates, ["data_source_counts"]) + + +def backfill_data_source_counts(apps, schema_editor): + if schema_editor.connection.vendor != "postgresql": + backfill_data_source_counts_python(apps) + return + + with schema_editor.connection.cursor() as cursor: + cursor.execute( + """ + UPDATE organizations_data_snapshot snapshot + SET data_source_counts = COALESCE( + ( + SELECT jsonb_agg( + jsonb_build_object( + 'source', + source_counts.source, + 'count', + source_counts.record_count + ) + ORDER BY source_counts.source + ) + FROM ( + SELECT CASE source_items.key + WHEN 'trudvsem' THEN 'vacancies' + ELSE source_items.key + END AS source, + CASE + WHEN jsonb_typeof(source_items.value) = 'array' + THEN jsonb_array_length(source_items.value) + WHEN source_items.value IN ( + 'null'::jsonb, + 'false'::jsonb, + '[]'::jsonb, + '{}'::jsonb, + '""'::jsonb + ) + THEN 0 + ELSE 1 + END AS record_count + FROM jsonb_each(snapshot.data) AS source_items + ) AS source_counts + WHERE source_counts.record_count > 0 + ), + '[]'::jsonb + ) + """ + ) + + +def clear_data_source_counts(apps, schema_editor): + snapshot_model = apps.get_model("organizations", "OrganizationDataSnapshot") + snapshot_model.objects.update(data_source_counts=[]) + + +class Migration(migrations.Migration): + dependencies = [ + ("organizations", "0004_seed_daily_snapshot_refresh_schedule"), + ] + + operations = [ + migrations.AddField( + model_name="organizationdatasnapshot", + name="data_source_counts", + field=models.JSONField( + default=list, + help_text="Готовый JSON data_sources для API v2", + verbose_name="счетчики источников", + ), + ), + migrations.RunPython( + backfill_data_source_counts, + reverse_code=clear_data_source_counts, + ), + ] diff --git a/src/organizations/models.py b/src/organizations/models.py index 
b7b5535..d5ca29c 100644 --- a/src/organizations/models.py +++ b/src/organizations/models.py @@ -6,6 +6,7 @@ from django.db import models from django.db.models import Q from django.utils.translation import gettext_lazy as _ +from organizations.data_sources import snapshot_data_source_summary from organizations.name_normalization import normalize_organization_name @@ -116,6 +117,11 @@ class OrganizationDataSnapshot(models.Model): default=list, help_text=_("Готовый JSON registries для API v2"), ) + data_source_counts = models.JSONField( + _("счетчики источников"), + default=list, + help_text=_("Готовый JSON data_sources для API v2"), + ) updated_at = models.DateTimeField( _("дата обновления"), auto_now=True, @@ -129,3 +135,14 @@ class OrganizationDataSnapshot(models.Model): def __str__(self) -> str: return f"Snapshot for {self.organization_id}" + + def save(self, *args, **kwargs) -> None: + update_fields = kwargs.get("update_fields") + if update_fields is None or "data" in update_fields: + self.data_source_counts = snapshot_data_source_summary(self.data) + if update_fields is not None: + kwargs["update_fields"] = list( + dict.fromkeys([*update_fields, "data_source_counts"]) + ) + + super().save(*args, **kwargs) diff --git a/src/organizations/serializers.py b/src/organizations/serializers.py index 9150bab..51e526d 100644 --- a/src/organizations/serializers.py +++ b/src/organizations/serializers.py @@ -4,7 +4,10 @@ from typing import Any from rest_framework import serializers -from organizations.api_enrichment import to_api_data_source +from organizations.data_sources import ( + data_source_summary, + snapshot_data_with_api_keys, +) from organizations.models import Organization @@ -36,8 +39,11 @@ class OrganizationSerializer(serializers.ModelSerializer): def get_data(self, obj) -> dict[str, Any]: snapshot = getattr(obj, "data_snapshot", None) if snapshot is not None: - data = _snapshot_data_with_api_keys(snapshot.data) data_sources = self.context.get("data_sources") + 
if data_sources is not None and not data_sources: + return {} + + data = snapshot_data_with_api_keys(snapshot.data) if data_sources is None: return data return { @@ -61,13 +67,17 @@ class OrganizationSerializer(serializers.ModelSerializer): def get_data_sources(self, obj) -> list[dict[str, int | str]]: snapshot = getattr(obj, "data_snapshot", None) if snapshot is not None: - data = _snapshot_data_with_api_keys(snapshot.data) - return _data_source_summary(data) + data_source_counts = getattr(snapshot, "data_source_counts", None) + if data_source_counts: + return data_source_counts + if "data" in snapshot.get_deferred_fields(): + return [] + return data_source_summary(snapshot_data_with_api_keys(snapshot.data)) enrichment = self.context.get("enrichment", {}).get(str(obj.uid)) if enrichment is None: return [] - return _data_source_summary(enrichment.data_presence) + return data_source_summary(enrichment.data_presence) def get_registries(self, obj) -> list[dict[str, str]]: snapshot = getattr(obj, "data_snapshot", None) @@ -84,22 +94,3 @@ class OrganizationSerializer(serializers.ModelSerializer): } for registry in enrichment.registries ] - - -def _snapshot_data_with_api_keys(data: dict[str, Any]) -> dict[str, Any]: - return {to_api_data_source(source): value for source, value in data.items()} - - -def _data_source_summary(data: dict[str, Any]) -> list[dict[str, int | str]]: - summary: list[dict[str, int | str]] = [] - for source in sorted(data): - value = data[source] - if isinstance(value, list): - count = len(value) - elif value: - count = 1 - else: - count = 0 - if count: - summary.append({"source": source, "count": count}) - return summary diff --git a/src/organizations/services.py b/src/organizations/services.py index 3e9a34a..d6c581e 100644 --- a/src/organizations/services.py +++ b/src/organizations/services.py @@ -23,6 +23,7 @@ from django.utils import timezone from registers.models import Organization as RegisterOrganization from organizations.api_enrichment 
import OrganizationApiEnrichmentService +from organizations.data_sources import snapshot_data_source_summary from organizations.models import Organization, OrganizationDataSnapshot _QUOTE_CHARS = "\"'«»„“”" @@ -132,6 +133,7 @@ class OrganizationDataSnapshotRefreshService: processed += 1 item = enrichment[str(organization.uid)] data = item.data_presence + data_source_counts = snapshot_data_source_summary(data) registries = [ { "id": registry.id, @@ -146,12 +148,14 @@ class OrganizationDataSnapshotRefreshService: OrganizationDataSnapshot( organization=organization, data=data, + data_source_counts=data_source_counts, registries=registries, ) ) continue snapshot.data = data + snapshot.data_source_counts = data_source_counts snapshot.registries = registries snapshot.updated_at = timezone.now() update_instances.append(snapshot) @@ -165,7 +169,7 @@ class OrganizationDataSnapshotRefreshService: if update_instances: OrganizationDataSnapshot.objects.bulk_update( update_instances, - fields=["data", "registries", "updated_at"], + fields=["data", "data_source_counts", "registries", "updated_at"], batch_size=batch_size, ) updated += len(update_instances) diff --git a/src/organizations/views.py b/src/organizations/views.py index 5adcb11..b2f01de 100644 --- a/src/organizations/views.py +++ b/src/organizations/views.py @@ -285,9 +285,7 @@ class CachedReadOnlyMixin: class OrganizationViewSet(CachedReadOnlyMixin, ReadOnlyModelViewSet): """Read-only API for canonical organizations.""" - queryset = Organization.objects.select_related("data_snapshot").order_by( - "name", "uid" - ) + queryset = Organization.objects.order_by("name", "uid") serializer_class = OrganizationSerializer permission_classes = [IsAuthenticated] lookup_field = "uid" @@ -307,7 +305,10 @@ class OrganizationViewSet(CachedReadOnlyMixin, ReadOnlyModelViewSet): return super().get_permissions() def get_queryset(self): - queryset = super().get_queryset().select_related("data_snapshot") + if 
self._should_defer_snapshot_data(): + queryset = queryset.defer("data_snapshot__data") + if self.action != "list" or "has_registry" in self.request.query_params: return queryset @@ -320,6 +321,20 @@ class OrganizationViewSet(CachedReadOnlyMixin, ReadOnlyModelViewSet): return filterset.qs return queryset + def _should_defer_snapshot_data(self) -> bool: + if getattr(self, "action", None) != "list": + return False + + return not any( + name in self.request.query_params + for name in ( + "data", + "data_sources", + "exclude_data", + "exclude_data_sources", + ) + ) + @swagger_auto_schema( tags=[ORGANIZATIONS_TAG], operation_id="v2_organizations_list", diff --git a/tests/apps/organizations/test_api_v2.py b/tests/apps/organizations/test_api_v2.py index c620766..88c007f 100644 --- a/tests/apps/organizations/test_api_v2.py +++ b/tests/apps/organizations/test_api_v2.py @@ -12,6 +12,7 @@ from django.test import override_settings from django.test.utils import CaptureQueriesContext from django.urls import reverse from organizations.cache import invalidate_organization_api_cache +from organizations.filters import OrganizationFilter from organizations.models import Organization, OrganizationDataSnapshot from organizations.services import OrganizationDataSnapshotRefreshService from rest_framework import status @@ -195,6 +196,46 @@ class OrganizationsApiV2Test(APITestCase): ], ) + def test_list_default_uses_snapshot_summary_without_loading_full_data(self): + organization = Organization.objects.create( + name='ООО "Легкий снапшот"', + inn="7712345685", + kpp="771201008", + ogrn="1027700132203", + ) + OrganizationDataSnapshot.objects.create( + organization=organization, + data={ + "industrial": [{"id": index} for index in range(100)], + "fns_reports": [{"id": index} for index in range(50)], + }, + registries=[], + ) + + with CaptureQueriesContext(connection) as captured: + response = self.client.get( + reverse("api_v2:organizations:organizations-list"), + {"inn": organization.inn, 
"has_registry": "false"}, + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + item = response.data["data"][0] + self.assertEqual(item["data"], {}) + self.assertEqual( + item["data_sources"], + [ + {"source": "fns_reports", "count": 50}, + {"source": "industrial", "count": 100}, + ], + ) + full_snapshot_data_queries = [ + query["sql"] + for query in captured + if "ORGANIZATIONS_DATA_SNAPSHOT" in query["sql"].upper() + and '"data"' in query["sql"] + ] + self.assertEqual(full_snapshot_data_queries, []) + def test_list_returns_snapshot_data_when_sources_are_requested(self): organization = Organization.objects.create( name='ООО "Явные данные"', @@ -740,6 +781,18 @@ class OrganizationsApiV2Test(APITestCase): self.assertEqual(no_registry.data["meta"]["pagination"]["total_count"], 1) self.assertEqual(no_registry.data["data"][0]["uid"], str(without_registry.uid)) + def test_has_registry_filter_uses_uncorrelated_identity_subqueries(self): + filterset = OrganizationFilter( + data={"has_registry": "true"}, + queryset=Organization.objects.all(), + ) + + self.assertTrue(filterset.is_valid(), filterset.errors) + sql = str(filterset.qs.query).upper() + + self.assertIn(" IN ", sql) + self.assertNotIn("EXISTS", sql) + def test_list_defaults_to_has_registry_true(self): with_registry = Organization.objects.create( name='ООО "Дефолтный реестр"', @@ -852,6 +905,39 @@ class OrganizationsApiV2Test(APITestCase): self.assertEqual(has_fns.data["meta"]["pagination"]["total_count"], 1) self.assertEqual(has_fns.data["data"][0]["uid"], str(with_fns.uid)) + def test_has_fns_reports_filter_does_not_preload_report_identities(self): + organization = Organization.objects.create( + name='ООО "Отчетность без preload"', + inn="7800000004", + kpp="780001004", + ogrn="1027700133004", + ) + FinancialReport.objects.create( + external_id="fin-presence-no-preload", + ogrn=organization.ogrn, + file_name="fin_presence_no_preload.xlsx", + file_hash="d" * 64, + load_batch=1, + 
status=FinancialReport.Status.SUCCESS, + source=FinancialReport.SourceType.API, + ) + + with CaptureQueriesContext(connection) as captured: + response = self.client.get( + reverse("api_v2:organizations:organizations-list"), + {"has_fns_reports": "true", "has_registry": "false"}, + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data["meta"]["pagination"]["total_count"], 1) + distinct_report_queries = [ + query["sql"] + for query in captured + if "PARSERS_FINANCIAL_REPORT" in query["sql"].upper() + and "SELECT DISTINCT" in query["sql"].upper() + ] + self.assertEqual(distinct_report_queries, []) + def test_limits_response_data_sources(self): organization = Organization.objects.create( name='ООО "Источник"', diff --git a/tests/apps/organizations/test_services.py b/tests/apps/organizations/test_services.py index 1bdf659..559dcfa 100644 --- a/tests/apps/organizations/test_services.py +++ b/tests/apps/organizations/test_services.py @@ -99,6 +99,10 @@ class OrganizationDataSnapshotRefreshServiceTest(TestCase): snapshot.data["industrial"][0]["certificate_number"], "SNAPSHOT-BATCH-CERT", ) + self.assertEqual( + snapshot.data_source_counts, + [{"source": "industrial", "count": 1}], + ) def test_refresh_for_fns_batch_matches_by_ogrn(self): organization = Organization.objects.create(