"""Starlite web service that classifies an uploaded photo with a PyTorch model.

At import time the module loads the label maps from JSON, fetches example
photos from a VK community album and loads a pickled PyTorch model; the
resulting ``app`` object is the ASGI application.
"""
import io
import json
import logging
import os
from pathlib import Path

import requests
import torch
import torch.nn.functional as F
from PIL import Image
from starlite import (
    Body,
    Controller,
    MediaType,
    RequestEncodingType,
    Starlite,
    StaticFilesConfig,
    Template,
    TemplateConfig,
    UploadFile,
    get,
    post,
)
from starlite.contrib.jinja import JinjaTemplateEngine
from starlite.exceptions import ValidationException
from torchvision import transforms  # type: ignore

logger = logging.getLogger(__name__)

# Hide every GPU from torch; all inference in this module runs on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Mapping loaded from beerds.json, with "_" in each value replaced by a space.
with open("beerds.json", "r", encoding="utf-8") as f:
    dict_names = {
        key: name.replace("_", " ") for key, name in json.load(f).items()
    }

VK_URL = "https://api.vk.com/method/"
TOKEN = os.getenv("VK_TOKEN")
headers = {"Authorization": f"Bearer {TOKEN}"}  # not used in this module
group_id = 220240483
postfix = "?v=5.131"
VK_TIMEOUT = 10  # seconds; stops an unresponsive VK API from hanging startup

# Photo caption (VK "text" field) -> URL of the photo's "x"-sized copy.
IMAGES = {}


def get_images():
    """Populate IMAGES from the VK group's photo album.

    Calls VK ``photos.getAll`` for the group (up to 200 photos). For every
    photo, stores the URL of its size entry whose ``type`` is ``"x"`` under
    the photo's ``text``. If the request fails, or the payload carries no
    ``response.items``, a warning is logged and IMAGES is left as is, so the
    service can still start without example photos.
    """
    try:
        response = requests.get(
            f"{VK_URL}photos.getAll{postfix}",
            params={
                "access_token": TOKEN,
                "owner_id": f"-{group_id}",
                "count": 200,
            },
            timeout=VK_TIMEOUT,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.RequestException, ValueError):
        logger.exception("Failed to fetch images from VK")
        return

    items = (payload.get("response") or {}).get("items")
    if items is None:
        # VK reports API failures in an "error" object instead of "response".
        logger.warning(
            "VK photos.getAll returned no items: %s", payload.get("error")
        )
        return

    for item in items:
        for size in item.get("sizes") or []:
            if size.get("type") == "x":
                IMAGES[item.get("text")] = size.get("url")
                break


def load_model(model_path, device="cpu"):
    """Load a fully pickled PyTorch model and switch it to evaluation mode.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects, which can
    execute code. Only ever point this at model files from a trusted source.
    """
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()
    return model


get_images()
MODEL = load_model("./models/dogs_model.pth")

with open("labels.json", "r", encoding="utf-8") as f:
    data_labels = f.read()
# Predicted class index (as a string) -> label returned to the client.
labels_dict = json.loads(data_labels)

IMG_SIZE = (180, 180)
# Built once at import instead of on every prediction.
PREPROCESS = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def predict_image(image, model, device="cpu"):
    """Classify a PIL image with ``model``.

    The image is converted to RGB when needed (the normalisation expects
    exactly three channels), resized to IMG_SIZE, normalised from [0, 1] to
    [-1, 1] and run through the model as a batch of one.

    Args:
        image: PIL image to classify.
        model: callable torch module returning class logits.
        device: torch device for the input tensor. Defaults to "cpu" because
            this module hides all GPUs and loads MODEL on the CPU.

    Returns:
        (predicted_idx, probabilities): the arg-max class index as an int and
        the softmax probabilities over all classes as a NumPy array.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    input_tensor = PREPROCESS(image)
    input_batch = input_tensor.unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        output = model(input_batch)
    probabilities = F.softmax(output[0], dim=0)
    _, predicted_idx = torch.max(probabilities, 0)
    return predicted_idx.item(), probabilities.cpu().numpy()


class BeerdsController(Controller):
    """Classification endpoint for uploaded images."""

    path = "/breeds"

    @post("/", media_type=MediaType.JSON)
    async def beeds(
        self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
    ) -> dict:
        """Classify the uploaded image and return the top prediction.

        Returns:
            A JSON object with
            ``results``: ``{"<probability>": label}`` for the top class, and
            ``images``: ``[{"name": label, "url": url}]`` when an example
            photo for the label is known, otherwise ``[]``.

        Raises:
            ValidationException: (HTTP 400) if the upload cannot be decoded
                as an image.
        """
        body = await data.read()
        try:
            # convert() forces a full decode, so corrupt or non-image uploads
            # are rejected here rather than failing inside the model.
            image = Image.open(io.BytesIO(body)).convert("RGB")
        except OSError as exc:
            raise ValidationException(
                detail="Uploaded file is not a valid image"
            ) from exc

        predicted_idx, probabilities = predict_image(image, MODEL, "cpu")
        predicted_label = labels_dict[str(predicted_idx)]
        # A NumPy array is unhashable and JSON object keys must be strings,
        # so key the result by the top-class probability rendered as text.
        probability = float(probabilities[predicted_idx])

        url = IMAGES.get(predicted_label)
        images = [{"name": predicted_label, "url": url}] if url else []
        return {
            "results": {str(probability): predicted_label},
            "images": images,
        }


SITE_URL = "https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/"

SITEMAP_XML = f"""<?xml version="1.0" encoding="UTF-8"?>
<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
  <url>
    <loc>{SITE_URL}</loc>
    <lastmod>2023-05-01T19:01:03+00:00</lastmod>
  </url>
</urlset>
"""

ROBOTS_TXT = f"""User-agent: *
Allow: /

Sitemap: {SITE_URL}sitemap.xml
"""


class BaseController(Controller):
    """Landing page plus the sitemap.xml and robots.txt crawler files."""

    path = "/"

    @get("/")
    async def main(self) -> Template:
        """Render the landing page template."""
        return Template(name="index.html")

    @get("/sitemap.xml", media_type=MediaType.XML)
    async def sitemaps(self) -> bytes:
        """Return a sitemap listing the site root."""
        return SITEMAP_XML.encode()

    @get("/robots.txt", media_type=MediaType.TEXT)
    async def robots(self) -> str:
        """Allow all crawlers and point them at the sitemap."""
        return ROBOTS_TXT


app = Starlite(
    route_handlers=[BeerdsController, BaseController],
    static_files_config=[
        StaticFilesConfig(directories=[Path("static")], path="/static"),
    ],
    template_config=TemplateConfig(
        directory=Path("templates"),
        engine=JinjaTemplateEngine,
    ),
)