177 lines
6.5 KiB
Python
177 lines
6.5 KiB
Python
import asyncio
|
|
import io
|
|
import os
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from typing import Any, NewType, Protocol
|
|
from uuid import uuid4
|
|
|
|
from dataclasses_ujson.dataclasses_ujson import UJsonMixin # type: ignore
|
|
from PIL import Image
|
|
|
|
# Deliberately placed before `import torch` below so the variable is already
# set when torch is imported. NOTE(review): "-1" is presumably meant to hide
# all CUDA devices so that inference always runs on the CPU - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
|
import torch
|
|
from torchvision import transforms # type: ignore
|
|
|
|
from server.modules.attachments.domains.attachments import Attachment
|
|
from server.modules.descriptions.repository import ACharactersRepository, Breed
|
|
from server.modules.recognizer.repository import ARecognizerRepository, models
|
|
|
|
# Distinct static type for models returned by `load_model`; wraps
# `torch.nn.Module` via `typing.NewType`.
TorchModel = NewType("TorchModel", torch.nn.Module)
|
|
|
|
|
|
def load_model(model_path, device="cpu") -> TorchModel:
    """Load a serialized torch module and switch it to evaluation mode.

    Args:
        model_path: Location of the saved model, passed straight to
            ``torch.load``.
        device: Passed to ``torch.load`` as ``map_location``; defaults to
            ``"cpu"``.

    Returns:
        The loaded module wrapped in the ``TorchModel`` type.
    """
    # NOTE(review): weights_only=False makes torch.load unpickle the complete
    # object rather than only tensors, so the file must come from a trusted
    # source.
    network = torch.load(model_path, map_location=device, weights_only=False)
    network.eval()
    return TorchModel(network)
|
|
|
|
|
|
# Both classifiers are loaded once, as a side effect of importing this module,
# using load_model's default device ("cpu"). The paths are relative, so they
# resolve against the process working directory.
DOG_MODEL = load_model("server/models/dogs_model.pth")
CAT_MODEL = load_model("server/models/cats_model.pth")
|
|
|
|
|
|
class AttachmentService(Protocol):
    """Structural interface of the attachment service the recognizer depends on.

    Any object exposing a matching ``create`` coroutine satisfies it.
    """

    async def create(self, file: bytes) -> Attachment:
        """Create an ``Attachment`` from the raw bytes in ``file``."""
        ...
|
|
|
|
|
|
@dataclass
class ResultImages(UJsonMixin):
    """Reference images attached to a single breed in a response."""

    # Breed name.
    name: str
    # URLs of the reference images (a list, despite the singular field name).
    url: list[str]
|
|
|
|
|
|
@dataclass
class RecognizerResult(UJsonMixin):
    """Outcome of recognizing one uploaded image."""

    # Predicted probability -> breed name.
    results: dict
    # ResultImages entries, one per prediction.
    images: list
    # Breed name -> list of characteristics links; None for cat predictions.
    description: dict[str, list] | None
    # Id of the attachment created from the uploaded image.
    uploaded_attach_id: str | None
|
|
|
|
@dataclass
class SharingBeerds(UJsonMixin):
    """One breed of a stored recognition result, prepared for sharing."""

    # Alias of the breed as stored in the characters repository.
    alias: str
    # Breed name as stored in the characters repository.
    name: str
    # Reference images for the breed.
    images: list[ResultImages]
|
|
|
|
@dataclass
class SharingResult(UJsonMixin):
    """A stored recognition result in the shape returned by ``get_results``."""

    # Breeds linked to the stored result.
    beerds: list[SharingBeerds]
    # Id of the attachment the result was produced from.
    attachment_id: str
|
|
|
|
|
|
class RecognizerService:
    """Recognize cat and dog breeds on uploaded photos and share dog results.

    Collaborators:
        * the recognizer repository - label maps, reference image listings and
          stored results;
        * the attachment service - turns the uploaded bytes into an attachment;
        * the characters repository - the breed catalogue.
    """

    __slots__ = (
        "_repository",
        "_attachment_service",
        "_repository_characters",
        "_background_tasks",
    )

    # Input pipeline shared by every prediction: resize to 224x224, convert to
    # a tensor and normalise each of the three channels with mean 0.5 / std 0.5.
    # It has no per-call state, so it is built once instead of on every call.
    _PREPROCESS = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    def __init__(
        self,
        repository: ARecognizerRepository,
        attachment_service: AttachmentService,
        repository_characters: ACharactersRepository,
    ):
        self._repository = repository
        self._attachment_service = attachment_service
        self._repository_characters = repository_characters
        # Strong references to fire-and-forget tasks: the event loop itself
        # only keeps weak references, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    async def images_cats(self) -> dict:
        """Return the repository's cat reference image listing."""
        return await self._repository.images_cats()

    async def images_dogs(self) -> dict:
        """Return the repository's dog reference image listing."""
        return await self._repository.images_dogs()

    async def create_result(
        self, attachment: Attachment, user_id: str, device_id: str, beerd_names: list[str]
    ) -> None:
        """Persist a recognition result linked to the named breeds.

        Args:
            attachment: Attachment created from the uploaded image.
            user_id: Id of the user who uploaded the image.
            device_id: Id of the device the upload came from.
            beerd_names: Breed names to link; names that match no entry of the
                characters repository are ignored.
        """
        wanted_names = set(beerd_names)
        characters = await self._repository_characters.get_characters()
        await self._repository.create_result_with_beerds(
            models.Results(
                id=str(uuid4()),
                attachment_id=attachment.id,
                user_id=user_id,
                device_id=device_id,
                created_at=datetime.now(UTC),
            ),
            [ch.id for ch in characters if ch.name in wanted_names],
        )

    async def get_results(self, result_id: str) -> SharingResult | None:
        """Return the stored result with id ``result_id`` in shareable form.

        Returns:
            The matching result, or ``None`` when no stored result has that id.

        Raises:
            KeyError: If a linked breed is missing from the characters
                repository or has no entry in the dog image listing.
        """
        stored_results = await self._repository.get_results()
        breeds_by_id: dict[str, Breed] = {b.id: b for b in await self._repository_characters.get_characters()}
        images_dogs = await self._repository.images_dogs()

        for stored in stored_results:
            if stored.result.id != result_id:
                continue
            shared: list[SharingBeerds] = []
            for link in stored.beerds:
                breed = breeds_by_id[link.beerd_id]
                # The image listing is keyed by the underscore form of the name.
                label = breed.name.replace(" ", "_")
                # Results are only ever stored by predict_dog_image and the file
                # names come from the dog listing, so the gallery must point at
                # the dog assets folder of that label (the previous
                # "/static/assets/cat/<name with spaces>/" path did not match).
                gallery = self._reference_images("dog", label, breed.name, images_dogs[label])
                shared.append(SharingBeerds(alias=breed.alias, name=breed.name, images=[gallery]))
            return SharingResult(beerds=shared, attachment_id=stored.result.attachment_id)
        return None

    async def predict_dog_image(self, image: bytes, user_id: str, device_id: str | None) -> RecognizerResult:
        """Recognize the dog breed on ``image`` and store the outcome.

        The upload is turned into an attachment, the top predictions are
        computed, and saving the result is scheduled as a background task so
        the response does not wait for it.

        Args:
            image: Raw bytes of the uploaded picture.
            user_id: Id of the uploading user.
            device_id: Id of the uploading device; ``None`` is stored as
                ``"mobile"``.
        """
        if device_id is None:
            device_id = "mobile"
        attachment = await self._attachment_service.create(image)
        # NOTE(review): inference is synchronous and runs on the event loop.
        predictions = self._predict(image, DOG_MODEL)
        images_dogs = await self._repository.images_dogs()
        labels = self._repository.labels_dogs()

        # NOTE(review): keyed by probability, so two predictions with exactly
        # the same probability overwrite each other. Kept because the mapping
        # is part of the response format.
        results: dict[float, str] = {}
        images: list[ResultImages] = []
        description: dict[str, list] = {}
        for predicted_idx, probability in predictions:
            label: str = labels[str(predicted_idx)]
            name = label.replace("_", " ")
            images.append(self._reference_images("dog", label, name, images_dogs[label]))
            description.setdefault(name, []).append(f"/dogs-characteristics/{name.replace(' ', '_')}")
            results[probability] = name

        save_task = asyncio.create_task(
            self.create_result(attachment, user_id, device_id, list(results.values()))
        )
        # Hold the task until it completes (see _background_tasks in __init__).
        self._background_tasks.add(save_task)
        save_task.add_done_callback(self._background_tasks.discard)

        return RecognizerResult(
            results=results, images=images, description=description, uploaded_attach_id=attachment.id
        )

    async def predict_cat_image(self, image: bytes) -> RecognizerResult:
        """Recognize the cat breed on ``image``.

        Unlike the dog flow, nothing is stored besides the attachment and the
        returned ``description`` is ``None``.
        """
        attachment = await self._attachment_service.create(image)
        # NOTE(review): inference is synchronous and runs on the event loop.
        predictions = self._predict(image, CAT_MODEL)
        images_cats = await self._repository.images_cats()
        labels = self._repository.labels_cats()

        # NOTE(review): same probability-keyed mapping as in predict_dog_image.
        results: dict[float, str] = {}
        images: list[ResultImages] = []
        for predicted_idx, probability in predictions:
            label: str = labels[str(predicted_idx)]
            name = label.replace("_", " ")
            images.append(self._reference_images("cat", label, name, images_cats[label]))
            results[probability] = name

        return RecognizerResult(results=results, images=images, description=None, uploaded_attach_id=attachment.id)

    @staticmethod
    def _reference_images(species: str, label: str, name: str, file_names) -> ResultImages:
        """Build the gallery of one breed from its static asset file names."""
        return ResultImages(
            name=name,
            url=[f"/static/assets/{species}/{label}/{file_name}" for file_name in file_names],
        )

    def _predict(self, image: bytes, model, device="cpu", top_k: int = 5) -> list[tuple[int, float]]:
        """Run ``model`` on ``image`` and return its most probable classes.

        Args:
            image: Raw bytes of an image file readable by PIL.
            model: Callable torch model; ``output[0]`` of its result is treated
                as the class scores of the single image in the batch.
            device: Device the input batch is moved to before the forward pass.
            top_k: Maximum number of predictions to return; clamped to the
                number of classes the model outputs.

        Returns:
            ``(class_index, probability)`` pairs in the order produced by
            ``torch.topk``.
        """
        with Image.open(io.BytesIO(image)) as picture:
            # Uploads may be RGBA, greyscale or palette images; the pipeline
            # normalises exactly three channels, so force RGB first.
            rgb_picture = picture.convert("RGB")
        input_batch = self._PREPROCESS(rgb_picture).unsqueeze(0).to(device)  # add the batch dimension

        with torch.no_grad():
            output = model(input_batch)

        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        # torch.topk raises when k exceeds the number of classes.
        k = min(top_k, probabilities.shape[0])
        top_probs, top_indices = torch.topk(probabilities, k)

        return [(index.item(), float(prob.item())) for index, prob in zip(top_indices, top_probs)]
|