beerds/server/modules/recognizer/repository/repository.py

138 lines
5.5 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from functools import lru_cache
from uuid import uuid4
import ujson
from aiocache import Cache, cached # type: ignore
from sqlalchemy import insert, select
from server.infra.db import AsyncDB
from server.modules.recognizer.repository import models as rm
@dataclass
class ResultWithBeerds:
result: rm.Results
beerds: list[rm.ResultBeerds]
class ARecognizerRepository(metaclass=ABCMeta):
@abstractmethod
async def images_dogs(self) -> dict:
pass
@abstractmethod
async def images_cats(self) -> dict:
pass
@abstractmethod
def labels_dogs(self) -> dict:
pass
@abstractmethod
def labels_cats(self) -> dict:
pass
@abstractmethod
async def get_results(self) -> list[ResultWithBeerds]:
"""Получить **все** результаты (кэшируется)."""
@abstractmethod
async def create_result_with_beerds(self, result: rm.Results, beerd_ids: list[str]) -> None:
"""
Создать новый результат и сразу же вставить связанные `ResultBeerds`.
`beerd_ids` список id пород, которые должны быть привязаны к
результату. Если список пуст, создаётся только результат.
"""
class RecognizerRepository(ARecognizerRepository):
def __init__(self, db: AsyncDB):
self._db = db
@cached(ttl=60, cache=Cache.MEMORY)
async def images_dogs(self) -> dict:
with open("server/modules/recognizer/repository/meta/images.json") as f: # noqa: ASYNC230
return ujson.loads(f.read())["dog"]
@cached(ttl=60, cache=Cache.MEMORY)
async def images_cats(self) -> dict:
with open("server/modules/recognizer/repository/meta/images.json") as f: # noqa: ASYNC230
return ujson.loads(f.read())["cat"]
@lru_cache
def labels_cats(self) -> dict:
with open("server/modules/recognizer/repository/meta/labels_cats.json") as f: # noqa: ASYNC230
data_labels = f.read()
return ujson.loads(data_labels)
@lru_cache
def labels_dogs(self) -> dict:
with open("server/modules/recognizer/repository/meta/labels_dogs.json") as f: # noqa: ASYNC230
data_labels = f.read()
return ujson.loads(data_labels)
async def create_result_with_beerds(self, result: rm.Results, beerd_ids: list[str]) -> None:
"""
Создаёт запись в ``recognizer_results`` и сразу же добавляет
одну запись в ``recognizer_results_beerds`` (если передан список
beerd_id) все в одном `INSERT`‑запросе и одной транзакции.
При отсутствии `id` у результата генерируется uuid4.
"""
# --------------------------------------------------------------------
# 1⃣ Подготовим объект результата
# --------------------------------------------------------------------
if not result.id:
result.id = str(uuid4())
# --------------------------------------------------------------------
# 2⃣ Откроем транзакцию и добавим результат
# --------------------------------------------------------------------
async with self._db.async_session() as session:
async with session.begin(): # начинается транзакция
session.add(result) # INSERT recognizer_results
# Если есть связанные beerd, делаем один bulkINSERT
if beerd_ids:
values = [
{
"id": str(uuid4()),
"recognizer_results_id": result.id,
"beerd_id": beerd_id,
}
for beerd_id in beerd_ids
]
# Один INSERT … VALUES (..), (..), …
await session.execute(insert(rm.ResultBeerds).values(values))
await session.commit() # завершаем транзакцию
@cached(ttl=60, cache=Cache.MEMORY)
async def get_results(self) -> list[ResultWithBeerds]:
async with self._db.async_session() as session:
# 1⃣ Получаем все результаты
stmt_res = select(rm.Results)
res_res = await session.execute(stmt_res)
results = res_res.scalars().all()
if not results:
return []
# 2⃣ Получаем все beerds, относящиеся к этим результатам
res_ids = [r.id for r in results]
stmt_beerds = (
select(rm.ResultBeerds).where(rm.ResultBeerds.recognizer_results_id.in_(res_ids)) # type:ignore
)
res_beerds = await session.execute(stmt_beerds)
beerds_list = res_beerds.scalars().all()
# 3⃣ Формируем карту id → beerds
by_res: dict[str, list[rm.ResultBeerds]] = {}
for b in beerds_list:
by_res.setdefault(b.recognizer_results_id, []).append(b)
# 4⃣ Собираем DTOы
return [ResultWithBeerds(result=r, beerds=by_res.get(r.id, [])) for r in results]