145 lines
5.7 KiB
Python
145 lines
5.7 KiB
Python
from abc import ABCMeta, abstractmethod
|
||
from dataclasses import dataclass
|
||
from functools import lru_cache
|
||
from uuid import uuid4
|
||
|
||
import ujson
|
||
from aiocache import Cache, cached # type: ignore
|
||
from sqlalchemy import insert, select
|
||
|
||
from server.infra.db import AsyncDB
|
||
from server.modules.recognizer.repository import models as rm
|
||
|
||
|
||
@dataclass
|
||
class ResultWithBeerds:
|
||
result: rm.Results
|
||
beerds: list[rm.ResultBeerds]
|
||
|
||
|
||
@dataclass
|
||
class ResultBeerds:
|
||
beerd_id: str
|
||
probability: float
|
||
|
||
|
||
class ARecognizerRepository(metaclass=ABCMeta):
|
||
@abstractmethod
|
||
async def images_dogs(self) -> dict:
|
||
pass
|
||
|
||
@abstractmethod
|
||
async def images_cats(self) -> dict:
|
||
pass
|
||
|
||
@abstractmethod
|
||
def labels_dogs(self) -> dict:
|
||
pass
|
||
|
||
@abstractmethod
|
||
def labels_cats(self) -> dict:
|
||
pass
|
||
|
||
@abstractmethod
|
||
async def get_results(self) -> list[ResultWithBeerds]:
|
||
"""Получить **все** результаты (кэшируется)."""
|
||
|
||
@abstractmethod
|
||
async def create_result_with_beerds(self, result: rm.Results, beerd_ids: list[ResultBeerds]) -> None:
|
||
"""
|
||
Создать новый результат и сразу же вставить связанные `ResultBeerds`.
|
||
|
||
`beerd_ids` – список id пород, которые должны быть привязаны к
|
||
результату. Если список пуст, создаётся только результат.
|
||
"""
|
||
|
||
|
||
class RecognizerRepository(ARecognizerRepository):
|
||
def __init__(self, db: AsyncDB):
|
||
self._db = db
|
||
|
||
@cached(ttl=60, cache=Cache.MEMORY)
|
||
async def images_dogs(self) -> dict:
|
||
with open("server/modules/recognizer/repository/meta/images.json") as f: # noqa: ASYNC230
|
||
return ujson.loads(f.read())["dog"]
|
||
|
||
@cached(ttl=60, cache=Cache.MEMORY)
|
||
async def images_cats(self) -> dict:
|
||
with open("server/modules/recognizer/repository/meta/images.json") as f: # noqa: ASYNC230
|
||
return ujson.loads(f.read())["cat"]
|
||
|
||
@lru_cache
|
||
def labels_cats(self) -> dict:
|
||
with open("server/modules/recognizer/repository/meta/labels_cats.json") as f: # noqa: ASYNC230
|
||
data_labels = f.read()
|
||
return ujson.loads(data_labels)
|
||
|
||
@lru_cache
|
||
def labels_dogs(self) -> dict:
|
||
with open("server/modules/recognizer/repository/meta/labels_dogs.json") as f: # noqa: ASYNC230
|
||
data_labels = f.read()
|
||
return ujson.loads(data_labels)
|
||
|
||
async def create_result_with_beerds(self, result: rm.Results, beerd_ids: list[ResultBeerds]) -> None:
|
||
"""
|
||
Создаёт запись в ``recognizer_results`` и сразу же добавляет
|
||
одну запись в ``recognizer_results_beerds`` (если передан список
|
||
beerd_id) – все в одном `INSERT`‑запросе и одной транзакции.
|
||
|
||
При отсутствии `id` у результата генерируется uuid4.
|
||
"""
|
||
# --------------------------------------------------------------------
|
||
# 1️⃣ Подготовим объект результата
|
||
# --------------------------------------------------------------------
|
||
if not result.id:
|
||
result.id = str(uuid4())
|
||
|
||
# --------------------------------------------------------------------
|
||
# 2️⃣ Откроем транзакцию и добавим результат
|
||
# --------------------------------------------------------------------
|
||
async with self._db.async_session() as session:
|
||
async with session.begin(): # начинается транзакция
|
||
session.add(result) # INSERT recognizer_results
|
||
|
||
# Если есть связанные beerd, делаем один bulk‑INSERT
|
||
if beerd_ids:
|
||
values = [
|
||
{
|
||
"id": str(uuid4()),
|
||
"recognizer_results_id": result.id,
|
||
"beerd_id": beerd_id.beerd_id,
|
||
"probability": beerd_id.probability,
|
||
}
|
||
for beerd_id in beerd_ids
|
||
]
|
||
# Один INSERT … VALUES (..), (..), …
|
||
await session.execute(insert(rm.ResultBeerds).values(values))
|
||
await session.commit() # завершаем транзакцию
|
||
|
||
@cached(ttl=60, cache=Cache.MEMORY)
|
||
async def get_results(self) -> list[ResultWithBeerds]:
|
||
async with self._db.async_session() as session:
|
||
# 1️⃣ Получаем все результаты
|
||
stmt_res = select(rm.Results)
|
||
res_res = await session.execute(stmt_res)
|
||
results = res_res.scalars().all()
|
||
|
||
if not results:
|
||
return []
|
||
|
||
# 2️⃣ Получаем все beerds, относящиеся к этим результатам
|
||
res_ids = [r.id for r in results]
|
||
stmt_beerds = (
|
||
select(rm.ResultBeerds).where(rm.ResultBeerds.recognizer_results_id.in_(res_ids)) # type:ignore
|
||
)
|
||
res_beerds = await session.execute(stmt_beerds)
|
||
beerds_list = res_beerds.scalars().all()
|
||
|
||
# 3️⃣ Формируем карту id → beerds
|
||
by_res: dict[str, list[rm.ResultBeerds]] = {}
|
||
for b in beerds_list:
|
||
by_res.setdefault(b.recognizer_results_id, []).append(b)
|
||
|
||
# 4️⃣ Собираем DTO‑ы
|
||
return [ResultWithBeerds(result=r, beerds=by_res.get(r.id, [])) for r in results]
|