import asyncio
import io
import json
import os
from pathlib import Path

import torch
import torch.nn.functional as F
from PIL import Image
from starlite import (
    Body,
    Controller,
    MediaType,
    RequestEncodingType,
    Starlite,
    StaticFilesConfig,
    Template,
    TemplateConfig,
    UploadFile,
    get,
    post,
)
from starlite.contrib.jinja import JinjaTemplateEngine
from starlite.exceptions import ValidationException
from torchvision import transforms  # type: ignore

# Hide all GPUs from torch so every model and tensor in this module lives on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def load_model(model_path, device="cpu"):
    """Load a fully pickled torch model from ``model_path`` and switch it to eval mode.

    ``weights_only=False`` unpickles arbitrary objects, so only ever point this
    at trusted, locally shipped model files.
    """
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()
    return model


def _load_json(path):
    """Read and parse a UTF-8 encoded JSON file."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# All paths are relative to the process working directory.
# Maps a human-readable dog breed name to an example image URL.
IMAGES = _load_json("server/meta/images.json")

DOG_MODEL = load_model("server/models/dogs_model.pth")
CAT_MODEL = load_model("server/models/cats_model.pth")

# Map stringified class indices (model output positions) to breed labels.
labels_dogs = _load_json("server/meta/labels_dogs.json")
labels_cats = _load_json("server/meta/labels_cats.json")

# Preprocessing shared by both models; built once instead of per request.
_IMG_SIZE = (224, 224)
_PREPROCESS = transforms.Compose(
    [
        transforms.Resize(_IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def predict_image(image, model, device="cpu", k=5) -> list[tuple]:
    """Run ``model`` on a PIL image and return its top-``k`` predictions.

    Args:
        image: PIL image. Converted to RGB first if it is in another mode
            (RGBA, L, P, CMYK, ...), because the normalisation is 3-channel.
        model: torch module. Assumed to return logits shaped
            (1, num_classes); the code indexes ``output[0]`` and softmaxes
            over dim 0 — TODO confirm for new models.
        device: device the model lives on. Defaults to "cpu" because this
            module hides CUDA devices.
        k: maximum number of predictions; clamped to the number of classes.

    Returns:
        List of ``(class_index, probability)`` tuples, highest probability first.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    # unsqueeze(0) adds the batch dimension expected by the model.
    input_batch = _PREPROCESS(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_batch)
    probabilities = F.softmax(output[0], dim=0)
    k = min(k, probabilities.numel())
    topk_probs, topk_idx = torch.topk(probabilities, k)
    return list(zip(topk_idx.tolist(), topk_probs.tolist()))


def _predict_upload(body: bytes, model) -> list[tuple]:
    """Decode uploaded image bytes and return ``predict_image`` results on CPU.

    Raises:
        ValidationException: if ``body`` cannot be decoded as an image, so the
            client receives a 400 instead of an unhandled 500.
    """
    try:
        with Image.open(io.BytesIO(body)) as img:
            # convert() forces the full decode, so truncated files fail here too.
            rgb_image = img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        raise ValidationException(detail="Uploaded file is not a valid image") from exc
    return predict_image(rgb_image, model, "cpu")


class BeerdsController(Controller):
    """Classification endpoints: ``POST /beerds/dogs`` and ``POST /beerds/cats``."""

    path = "/beerds"

    @post("/dogs")
    async def beerds_dogs(
        self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
    ) -> dict:
        """Classify an uploaded dog photo.

        Returns ``{"results": {probability: breed}, "images": [{"name", "url"}]}``
        for the top predictions. Breeds without an entry in ``IMAGES`` are left
        out of ``images`` instead of failing the whole request.
        """
        body = await data.read()
        # Inference is CPU-bound; run it off the event loop so other requests
        # are not stalled while the model runs.
        predicted_data = await asyncio.to_thread(_predict_upload, body, DOG_MODEL)
        results = {}
        images = []
        for predicted_idx, probability in predicted_data:
            name = labels_dogs[str(predicted_idx)].replace("_", " ")
            # Response keeps its original shape: keyed by probability, so two
            # identical probabilities overwrite each other.
            results[probability] = name
            url = IMAGES.get(name)
            if url is not None:
                images.append({"name": name, "url": url})
        return {
            "results": results,
            "images": images,
        }

    @post("/cats")
    async def beerds_cats(
        self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
    ) -> dict:
        """Classify an uploaded cat photo.

        Returns ``{"results": {probability: breed}, "images": []}``; the cat
        endpoint does not provide example images.
        """
        body = await data.read()
        # Same off-loop inference as the dog endpoint.
        predicted_data = await asyncio.to_thread(_predict_upload, body, CAT_MODEL)
        results = {
            probability: labels_cats[str(predicted_idx)]
            for predicted_idx, probability in predicted_data
        }
        return {
            "results": results,
            "images": [],
        }


# Served verbatim at /sitemap.xml. The XML declaration must be the very first
# bytes of the document, so the literal starts right after the opening quotes.
_SITEMAP_XML = """<?xml version="1.0" encoding="UTF-8"?>
<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
  <url>
    <loc>https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/</loc>
    <lastmod>2025-04-21T19:01:03+00:00</lastmod>
  </url>
  <url>
    <loc>https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/cats</loc>
    <lastmod>2025-04-21T19:01:03+00:00</lastmod>
  </url>
</urlset>
"""

# Served verbatim at /robots.txt.
_ROBOTS_TXT = """User-agent: *
Allow: /

Sitemap: https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/sitemap.xml
"""


class BaseController(Controller):
    """HTML pages plus the sitemap and robots.txt."""

    path = "/"

    @get("/")
    async def dogs(self) -> Template:
        """Render the dog-classifier landing page."""
        return Template(name="dogs.html")

    @get("/cats")
    async def cats(self) -> Template:
        """Render the cat-classifier page."""
        return Template(name="cats.html")

    @get("/contacts")
    async def contacts(self) -> Template:
        """Render the contacts page."""
        return Template(name="contacts.html")

    @get("/sitemap.xml", media_type=MediaType.XML)
    async def sitemaps(self) -> bytes:
        """Return the static sitemap as UTF-8 bytes."""
        return _SITEMAP_XML.encode()

    @get("/robots.txt", media_type=MediaType.TEXT)
    async def robots(self) -> str:
        """Return the static robots.txt body."""
        return _ROBOTS_TXT


app = Starlite(
    # NOTE(review): debug=True sends tracebacks to clients; disable it for the
    # public deployment.
    debug=True,
    route_handlers=[BeerdsController, BaseController],
    static_files_config=[
        StaticFilesConfig(directories=[Path("server/static")], path="/static"),
    ],
    template_config=TemplateConfig(
        directory=Path("server/templates"),
        engine=JinjaTemplateEngine,
    ),
)