from abc import ABCMeta, abstractmethod from dataclasses import dataclass from functools import lru_cache from uuid import uuid4 import ujson from aiocache import Cache, cached # type: ignore from sqlalchemy import insert, select from server.infra.db import AsyncDB from server.modules.recognizer.repository import models as rm @dataclass class ResultWithBeerds: result: rm.Results beerds: list[rm.ResultBeerds] @dataclass class ResultBeerds: beerd_id: str probability: float class ARecognizerRepository(metaclass=ABCMeta): @abstractmethod async def images_dogs(self) -> dict: pass @abstractmethod async def images_cats(self) -> dict: pass @abstractmethod def labels_dogs(self) -> dict: pass @abstractmethod def labels_cats(self) -> dict: pass @abstractmethod async def get_results(self) -> list[ResultWithBeerds]: """Получить **все** результаты (кэшируется).""" @abstractmethod async def create_result_with_beerds(self, result: rm.Results, beerd_ids: list[ResultBeerds]) -> None: """ Создать новый результат и сразу же вставить связанные `ResultBeerds`. `beerd_ids` – список id пород, которые должны быть привязаны к результату. Если список пуст, создаётся только результат. """ class RecognizerRepository(ARecognizerRepository): def __init__(self, db: AsyncDB): self._db = db @cached(ttl=60, cache=Cache.MEMORY) async def images_dogs(self) -> dict: with open("server/modules/recognizer/repository/meta/images.json") as f: # noqa: ASYNC230 return ujson.loads(f.read())["dog"] @cached(ttl=60, cache=Cache.MEMORY) async def images_cats(self) -> dict: with open("server/modules/recognizer/repository/meta/images.json") as f: # noqa: ASYNC230 return ujson.loads(f.read())["cat"] @lru_cache def labels_cats(self) -> dict: with open("server/modules/recognizer/repository/meta/labels_cats.json") as f: # noqa: ASYNC230 data_labels = f.read() return ujson.loads(data_labels) @lru_cache def labels_dogs(self) -> dict: with open("server/modules/recognizer/repository/meta/labels_dogs.json") as f: # noqa: ASYNC230 data_labels = f.read() return ujson.loads(data_labels) async def create_result_with_beerds(self, result: rm.Results, beerd_ids: list[ResultBeerds]) -> None: """ Создаёт запись в ``recognizer_results`` и сразу же добавляет одну запись в ``recognizer_results_beerds`` (если передан список beerd_id) – все в одном `INSERT`‑запросе и одной транзакции. При отсутствии `id` у результата генерируется uuid4. """ # -------------------------------------------------------------------- # 1️⃣ Подготовим объект результата # -------------------------------------------------------------------- if not result.id: result.id = str(uuid4()) # -------------------------------------------------------------------- # 2️⃣ Откроем транзакцию и добавим результат # -------------------------------------------------------------------- async with self._db.async_session() as session: async with session.begin(): # начинается транзакция session.add(result) # INSERT recognizer_results # Если есть связанные beerd, делаем один bulk‑INSERT if beerd_ids: values = [ { "id": str(uuid4()), "recognizer_results_id": result.id, "beerd_id": beerd_id.beerd_id, "probability": beerd_id.probability, } for beerd_id in beerd_ids ] # Один INSERT … VALUES (..), (..), … await session.execute(insert(rm.ResultBeerds).values(values)) await session.commit() # завершаем транзакцию @cached(ttl=60, cache=Cache.MEMORY) async def get_results(self) -> list[ResultWithBeerds]: async with self._db.async_session() as session: # 1️⃣ Получаем все результаты stmt_res = select(rm.Results) res_res = await session.execute(stmt_res) results = res_res.scalars().all() if not results: return [] # 2️⃣ Получаем все beerds, относящиеся к этим результатам res_ids = [r.id for r in results] stmt_beerds = ( select(rm.ResultBeerds).where(rm.ResultBeerds.recognizer_results_id.in_(res_ids)) # type:ignore ) res_beerds = await session.execute(stmt_beerds) beerds_list = res_beerds.scalars().all() # 3️⃣ Формируем карту id → beerds by_res: dict[str, list[rm.ResultBeerds]] = {} for b in beerds_list: by_res.setdefault(b.recognizer_results_id, []).append(b) # 4️⃣ Собираем DTO‑ы return [ResultWithBeerds(result=r, beerds=by_res.get(r.id, [])) for r in results]