# beerds/server/main.py
# (repository listing metadata: 160 lines, 4.0 KiB, Python)

from pathlib import Path
from PIL import Image
from starlite import (
Controller,
StaticFilesConfig,
get,
post,
Body,
MediaType,
RequestEncodingType,
Starlite,
UploadFile,
Template,
TemplateConfig,
)
from starlite.contrib.jinja import JinjaTemplateEngine
import io
import os
import json
import requests
import torch
from torchvision import transforms # type: ignore
import torch.nn.functional as F
# Hide every CUDA device from torch so the app runs CPU-only.
# NOTE(review): this is set after `import torch`; it presumably still applies
# because torch initialises CUDA lazily on first use — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Display names loaded from beerds.json, with underscores replaced by spaces.
# Decode explicitly as UTF-8: open()'s default encoding is locale-dependent,
# so non-ASCII names would otherwise load differently across platforms.
with open("beerds.json", "r", encoding="utf-8") as f:
    dict_names = {
        key: name.replace("_", " ") for key, name in json.load(f).items()
    }

# VK API settings. VK_URL, TOKEN, group_id and postfix are used by get_images().
VK_URL = "https://api.vk.com/method/"
TOKEN = os.getenv("VK_TOKEN")  # None when VK_TOKEN is not set
headers = {"Authorization": f"Bearer {TOKEN}"}
group_id = 220240483
postfix = "?v=5.131"  # VK API version query string

# Photo caption -> photo URL, filled in by get_images().
IMAGES = {}
def get_images():
    """Fetch the VK group's photos and record one URL per photo caption.

    Calls VK ``photos.getAll`` for owner ``-group_id`` (first 200 photos) and,
    for every photo, stores the URL of its size entry whose ``type`` is ``"x"``
    in the module-level ``IMAGES`` dict, keyed by the photo's ``text`` field.
    NOTE(review): the caption text is presumably the breed label and must
    match the values in labels.json — confirm.

    Raises:
        requests.HTTPError: if VK answers with a non-2xx status.
        requests.Timeout: if VK does not answer within 30 seconds.
        RuntimeError: if the payload has no ``"response"`` object (VK reports
            API errors, e.g. a bad token, this way).
    """
    r = requests.get(
        f"{VK_URL}photos.getAll{postfix}",
        # requests URL-encodes these and appends them to the existing query.
        params={"access_token": TOKEN, "owner_id": f"-{group_id}", "count": 200},
        timeout=30,  # without a timeout a stalled VK call blocks startup forever
    )
    r.raise_for_status()
    payload = r.json()
    response = payload.get("response")
    if response is None:
        raise RuntimeError(f"VK photos.getAll failed: {payload.get('error')}")
    for item in response.get("items", []):
        for size in item.get("sizes", []):
            if size.get("type") == "x":
                # IMAGES is mutated, not rebound, so no `global` is needed.
                IMAGES[item.get("text")] = size.get("url")
                break
def load_model(model_path, device="cpu"):
    """Unpickle a whole torch model from disk and put it in evaluation mode.

    Args:
        model_path: path to the saved model, presumably written with
            ``torch.save(model, path)`` (a full module, not a state dict).
        device: target for the loaded tensors, passed as ``map_location``.

    Returns:
        The loaded model, after ``eval()`` has been called on it.

    Security: ``weights_only=False`` runs the full pickle loader, which can
    execute arbitrary code. Only load model files from a trusted source.
    """
    restored = torch.load(
        model_path,
        weights_only=False,
        map_location=device,
    )
    restored.eval()  # inference behaviour for layers such as dropout/batch-norm
    return restored
# Startup side effects: populate IMAGES from VK, load the classifier once,
# and read the class-index -> label mapping used by the prediction endpoint.
get_images()
MODEL = load_model("./models/dogs_model.pth")
# Explicit UTF-8 so label text decodes the same regardless of system locale.
with open("labels.json", "r", encoding="utf-8") as f:
    data_labels = f.read()
# Keys are looked up as str(predicted_idx), i.e. stringified class indices.
labels_dict = json.loads(data_labels)
def predict_image(image, model, device="cuda"):
    """Classify a PIL image with ``model`` and return the top class.

    The image is converted to RGB, resized to 180x180, turned into a tensor
    and normalised with mean 0.5 / std 0.5 on each of the three channels.
    It is then run through ``model`` as a batch of one.

    Args:
        image: a PIL image of any mode; it is converted to RGB first.
        model: the classifier. Assumed to already live on ``device`` and be
            in eval mode — confirm at call sites.
        device: device for the input batch. NOTE(review): the "cuda" default
            does not match this CPU-only app; callers pass "cpu" explicitly.

    Returns:
        ``(predicted_idx, probabilities)``: the index of the most probable
        class as a Python int, and the softmax of ``output[0]`` as a NumPy
        array. This presumes the model outputs ``(1, num_classes)`` logits.
    """
    img_size = (180, 180)
    preprocess = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # Normalize expects exactly 3 channels. RGBA, greyscale or palette
    # uploads would otherwise produce a 4- or 1-channel tensor and fail.
    input_tensor = preprocess(image.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        output = model(input_batch)
    probabilities = F.softmax(output[0], dim=0)
    _, predicted_idx = torch.max(probabilities, 0)
    return predicted_idx.item(), probabilities.cpu().numpy()
class BeerdsController(Controller):
    """Breed-recognition API mounted at ``/breeds``."""

    path = "/breeds"

    # The handler returns a dict, so the response must be JSON. A text/*
    # media type expects str/bytes content and cannot serialise a dict.
    @post("/", media_type=MediaType.JSON)
    async def beeds(
        self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
    ) -> dict:
        """Classify an uploaded image and return the predicted breed.

        Returns:
            ``{"results": {confidence: label}, "images": [{"name", "url"}]}``.
            ``confidence`` is the predicted class's softmax probability as a
            string. ``url`` is ``None`` when no VK photo exists for the label.
        """
        body = await data.read()
        img_file = Image.open(io.BytesIO(body))
        predicted_idx, probabilities = predict_image(img_file, MODEL, "cpu")
        predicted_label = labels_dict[str(predicted_idx)]
        # A NumPy array is unhashable and cannot be a dict key. Use the
        # predicted class's probability, as a string, since JSON object keys
        # are strings.
        confidence = float(probabilities[predicted_idx])
        # .get(): a label without a matching VK photo must not 500 the request.
        images = [{"name": predicted_label, "url": IMAGES.get(predicted_label)}]
        return {
            "results": {str(confidence): predicted_label},
            "images": images,
        }
# Sitemap document served at /sitemap.xml, encoded to bytes once at import.
_SITEMAP_XML = """<?xml version="1.0" encoding="UTF-8"?>
<urlset
xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.sitemaps.org/schemas/sitemap/0.9
http://www.sitemaps.org/schemas/sitemap/0.9/sitemap.xsd">
<!-- created with Free Online Sitemap Generator www.xml-sitemaps.com -->
<url>
<loc>https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/</loc>
<lastmod>2023-05-01T19:01:03+00:00</lastmod>
</url>
</urlset>
""".encode()

# robots.txt body: allow every crawler and advertise the sitemap URL.
_ROBOTS_TXT = """
User-agent: *
Allow: /
Sitemap: https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/sitemap.xml
"""


class BaseController(Controller):
    """Site-root routes: the HTML landing page plus crawler helper files."""

    path = "/"

    @get("/")
    async def main(self) -> Template:
        """Render the landing page template ``index.html``."""
        return Template(name="index.html")

    @get("/sitemap.xml", media_type=MediaType.XML)
    async def sitemaps(self) -> bytes:
        """Return the static sitemap as UTF-8 bytes."""
        return _SITEMAP_XML

    @get("/robots.txt", media_type=MediaType.TEXT)
    async def robots(self) -> str:
        """Return the static robots.txt policy."""
        return _ROBOTS_TXT
# ASGI application: the API/page controllers, files under ./static served at
# /static, and Jinja templates rendered from ./templates.
app = Starlite(
    route_handlers=[BeerdsController, BaseController],
    static_files_config=[
        StaticFilesConfig(path="/static", directories=[Path("static")]),
    ],
    template_config=TemplateConfig(
        engine=JinjaTemplateEngine,
        directory=Path("templates"),
    ),
)