beerds/server/main.py

240 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pathlib import Path
import markdown
from PIL import Image
from starlite import (
Controller,
StaticFilesConfig,
get,
post,
Body,
MediaType,
RequestEncodingType,
Starlite,
UploadFile,
Template,
TemplateConfig,
HTTPException
)
from starlite.contrib.jinja import JinjaTemplateEngine
import io
import os
import json
import torch
from torchvision import transforms # type: ignore
import torch.nn.functional as F
# Hide every GPU from CUDA so the process only ever runs on the CPU.
# NOTE(review): torch is already imported at this point; this works only because
# CUDA is initialised lazily on first use — setting it before `import torch`
# would be more robust.
os.environ.update(CUDA_VISIBLE_DEVICES="-1")
def load_model(model_path, device="cpu"):
    """Deserialize a whole pickled PyTorch model and put it in inference mode.

    Args:
        model_path: Path to the saved model file. It is expected to hold a
            complete model object (not just a state_dict), since ``.eval()``
            is called on the loaded result.
        device: Value forwarded to ``map_location`` so tensors are restored
            onto that device.

    Returns:
        The loaded model after ``eval()`` has been called on it.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only load trusted, locally shipped model files.
    net = torch.load(model_path, map_location=device, weights_only=False)
    net.eval()
    return net
def _read_json(path):
    """Parse the JSON file at *path*, decoding it explicitly as UTF-8.

    An explicit encoding keeps the result independent of the host locale
    (e.g. cp1251 on a Russian Windows machine).
    """
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)


# Both classifiers are loaded once at import time and shared by all requests.
DOG_MODEL = load_model("server/models/dogs_model.pth")
CAT_MODEL = load_model("server/models/cats_model.pth")

# Class-index -> label maps. The handlers look them up with str(index)
# because JSON object keys are always strings.
labels_dogs = _read_json("server/meta/labels_dogs.json")
labels_cats = _read_json("server/meta/labels_cats.json")

# Gallery image file names, looked up as IMAGES[species][label] by the
# handlers; presumably {"dog": {label: [files]}, "cat": {...}} — verify
# against the file.
IMAGES = _read_json("server/meta/images.json")
def predict_image(image, model, device="cuda", k=5) -> list[tuple[int, float]]:
    """Classify *image* with *model* and return the top-*k* predictions.

    Args:
        image: A PIL image. Non-RGB modes (RGBA, L, P, CMYK, ...) are
            converted to RGB first because the normalization step uses
            3-channel mean/std values.
        model: Classifier placed on *device*; expected to be in eval mode.
        device: Device the input batch is moved to. NOTE: the default "cuda"
            fails when CUDA is unavailable; callers in this module pass "cpu".
        k: Number of predictions to return. It is capped at the number of
            classes the model outputs.

    Returns:
        List of ``(class_index, probability)`` pairs, most probable first.
    """
    img_size = (224, 224)
    preprocess = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.ToTensor(),
            # Shifts ToTensor's [0, 1] range to [-1, 1] per channel.
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # RGBA / grayscale / palette uploads would yield 4 or 1 channels and make
    # Normalize raise a shape error, so force 3-channel RGB.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    input_batch = preprocess(image).unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        output = model(input_batch)
        probabilities = F.softmax(output[0], dim=0)
    # torch.topk raises if asked for more elements than exist.
    topk_probs, topk_idx = torch.topk(probabilities, min(k, probabilities.numel()))
    return [
        (int(idx), float(prob))
        for idx, prob in zip(topk_idx.tolist(), topk_probs.tolist())
    ]
def load_breed_descriptions(directory):
    """Read every ``*.txt`` breed description file in *directory*.

    The file stem becomes the ``alias`` used in breed-page URLs. The same
    stem with underscores replaced by spaces becomes the display ``name``.

    Args:
        directory: Folder (str or Path) with one UTF-8 text file per breed.

    Returns:
        A list of ``{"name", "alias", "description"}`` dicts sorted by name.
        Each description has its surrounding whitespace stripped.
    """
    breeds = [
        {
            "name": path.stem.replace("_", " "),
            "alias": path.stem,
            "description": path.read_text(encoding="utf-8").strip(),
        }
        for path in Path(directory).glob("*.txt")
    ]
    # glob() order depends on the filesystem; sort for a stable listing.
    breeds.sort(key=lambda breed: breed["name"])
    return breeds


breed_dir = Path("server/meta/breed_descriptions")
# Registry of dog breeds used by the characteristics pages and the sitemap.
# (The name presumably means "DOGS_BREEDS"; it is kept as-is because other
# code refers to it.)
DOGS_BEERS = load_breed_descriptions(breed_dir)
async def _classify_upload(data, model, labels, species):
    """Run *model* on an uploaded image and build the JSON response payload.

    Shared by the dog and cat endpoints. They differ only in the model, the
    index->label map and the species segment of the static image paths.

    Args:
        data: The multipart upload that should contain an image.
        model: Classifier passed through to ``predict_image``.
        labels: Mapping of ``str(class_index)`` to the underscore-separated
            breed label.
        species: Key into ``IMAGES`` and the ``/static/assets/{species}/``
            URL prefix (``"dog"`` or ``"cat"`` here).

    Returns:
        ``{"results": {probability: display_name}, "images": [...]}``. Each
        image entry is ``{"name": display_name, "url": [static paths]}``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    body = await data.read()
    try:
        img_file = Image.open(io.BytesIO(body))
        # Image.open is lazy; decode now so corrupt/truncated uploads are
        # rejected here instead of failing deep inside preprocessing.
        img_file.load()
    except OSError as exc:  # also covers PIL.UnidentifiedImageError
        raise HTTPException(status_code=400, detail="Не удалось распознать изображение") from exc

    # NOTE(review): inference is CPU-bound and runs synchronously on the
    # event loop, blocking other requests; consider offloading to a thread.
    predicted_data = predict_image(img_file, model, "cpu")

    species_images = IMAGES.get(species, {})
    results = {}
    images = []
    for predicted_idx, probability in predicted_data:
        label = labels[str(predicted_idx)]
        name = label.replace("_", " ")
        images.append({
            "name": name,
            # Breeds without gallery images get an empty list instead of a 500.
            "url": [f"/static/assets/{species}/{label}/{i}" for i in species_images.get(label, [])],
        })
        # NOTE(review): keyed by probability, so two classes with exactly the
        # same score collapse into one entry; kept for API compatibility.
        results[probability] = name
    return {
        "results": results,
        "images": images,
    }


class BeerdsController(Controller):
    """Breed-recognition API: POST an image, get the top breed guesses."""

    path = "/beerds"

    @post("/dogs")
    async def beerds_dogs(
        self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
    ) -> dict:
        """Classify an uploaded dog photo with the dog model."""
        return await _classify_upload(data, DOG_MODEL, labels_dogs, "dog")

    @post("/cats")
    async def beerds_cats(
        self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
    ) -> dict:
        """Classify an uploaded cat photo with the cat model.

        Result values are display names (underscores replaced by spaces),
        consistent with the dog endpoint.
        """
        return await _classify_upload(data, CAT_MODEL, labels_cats, "cat")
# Public origin of the site (IDNA/punycode-encoded domain). Used for the
# absolute URLs in sitemap.xml and robots.txt.
SITE_URL = "https://xn-----6kcp3cadbabfh8a0a.xn--p1ai"


class BaseController(Controller):
    """Server-rendered HTML pages plus sitemap.xml and robots.txt."""

    path = "/"

    @get("/")
    async def dogs(self) -> Template:
        """Render the site root (dogs page)."""
        return Template(name="dogs.html")

    @get("/cats")
    async def cats(self) -> Template:
        """Render the cats page."""
        return Template(name="cats.html")

    @get("/contacts")
    async def contacts(self) -> Template:
        """Render the contacts page."""
        return Template(name="contacts.html")

    @get("/donate")
    async def donate(self) -> Template:
        """Render the donation page."""
        return Template(name="donate.html")

    @get("/dogs-characteristics")
    async def dogs_characteristics(self) -> Template:
        """Render the list of all dog breeds that have a description."""
        return Template(name="dogs-characteristics.html", context={"breeds": DOGS_BEERS})

    @get("/dogs-characteristics/{name:str}")
    async def beer_description(self, name: str) -> Template:
        """Render the description page of the breed whose alias is *name*.

        Raises:
            HTTPException: 404 when no breed has that alias.
        """
        breed = next((b for b in DOGS_BEERS if b.get("alias") == name), None)
        if breed is None:
            raise HTTPException(status_code=404, detail="Порода не найдена")
        return Template(name="beers-description.html", context={
            "text": markdown.markdown(breed.get("description")),
            "title": breed.get("name"),
            # A breed may have a description but no gallery images; render the
            # page without images instead of failing with a KeyError.
            "images": [f"/static/assets/dog/{name}/{i}" for i in IMAGES.get("dog", {}).get(name, [])],
        })

    @get("/sitemap.xml", media_type=MediaType.XML)
    async def sitemaps(self) -> bytes:
        """Return sitemap.xml (UTF-8 bytes) for the static pages and every breed page.

        Breed aliases come from description file names. Each one is
        percent-encoded for the URL path and XML-escaped for ``<loc>``.
        """
        from urllib.parse import quote
        from xml.sax.saxutils import escape

        lastmod = "2025-04-21T19:01:03+00:00"
        entries = []
        for b in DOGS_BEERS:
            loc = escape(f"{SITE_URL}/dogs-characteristics/{quote(b.get('alias'))}")
            entries.append(f"""
<url>
<loc>{loc}</loc>
<lastmod>{lastmod}</lastmod>
</url>
""")
        beers_url = "".join(entries)
        return f"""<?xml version="1.0" encoding="UTF-8"?>
<urlset
xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.sitemaps.org/schemas/sitemap/0.9
http://www.sitemaps.org/schemas/sitemap/0.9/sitemap.xsd">
<!-- created with Free Online Sitemap Generator www.xml-sitemaps.com -->
<url>
<loc>{SITE_URL}/</loc>
<lastmod>{lastmod}</lastmod>
</url>
<url>
<loc>{SITE_URL}/cats</loc>
<lastmod>{lastmod}</lastmod>
</url>
<url>
<loc>{SITE_URL}/donate</loc>
<lastmod>{lastmod}</lastmod>
</url>
<url>
<loc>{SITE_URL}/dogs-characteristics</loc>
<lastmod>{lastmod}</lastmod>
</url>
{beers_url}
</urlset>
""".encode()

    @get("/robots.txt", media_type=MediaType.TEXT)
    async def robots(self) -> str:
        """Return robots.txt allowing all crawlers and pointing at the sitemap."""
        return f"""
User-agent: *
Allow: /
Sitemap: {SITE_URL}/sitemap.xml
"""
# Debug mode makes Starlite return exception details in error responses,
# which should not be exposed publicly. It is therefore opt-in: set
# BEERDS_DEBUG=1 (or true/yes/on) for local development.
app = Starlite(
    debug=os.environ.get("BEERDS_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"},
    route_handlers=[BeerdsController, BaseController],
    static_files_config=[
        StaticFilesConfig(directories=[Path("server/static")], path="/static"),
    ],
    template_config=TemplateConfig(
        directory=Path("server/templates"),
        engine=JinjaTemplateEngine,
    ),
)