240 lines
7.1 KiB
Python
240 lines
7.1 KiB
Python
from pathlib import Path
|
||
|
||
import markdown
|
||
from PIL import Image
|
||
from starlite import (
|
||
Controller,
|
||
StaticFilesConfig,
|
||
get,
|
||
post,
|
||
Body,
|
||
MediaType,
|
||
RequestEncodingType,
|
||
Starlite,
|
||
UploadFile,
|
||
Template,
|
||
TemplateConfig,
|
||
HTTPException
|
||
)
|
||
from starlite.contrib.jinja import JinjaTemplateEngine
|
||
import io
|
||
import os
|
||
import json
|
||
|
||
import torch
|
||
from torchvision import transforms # type: ignore
|
||
import torch.nn.functional as F
|
||
|
||
# Hide every CUDA device from this process; the handlers in this file always
# request inference on "cpu".
# NOTE(review): this assignment runs after `import torch`. It presumably still
# takes effect because CUDA is only initialised on first use -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||
|
||
|
||
def load_model(model_path, device="cpu"):
    """Load a fully serialized torch model and switch it to evaluation mode.

    Args:
        model_path: Path of the checkpoint file handed to ``torch.load``.
        device: Target passed as ``map_location``; defaults to ``"cpu"``.

    Returns:
        The deserialized object after ``.eval()`` has been called on it.
    """
    # weights_only=False unpickles arbitrary Python objects, so this must only
    # ever be pointed at checkpoint files from a trusted source.
    loaded = torch.load(model_path, weights_only=False, map_location=device)
    loaded.eval()
    return loaded
|
||
|
||
|
||
# Both classifiers are loaded once, at import time, with load_model's default
# device ("cpu"), and are then shared by every request handler.
DOG_MODEL = load_model(model_path="server/models/dogs_model.pth")
CAT_MODEL = load_model(model_path="server/models/cats_model.pth")
|
||
|
||
# Metadata is read as UTF-8 explicitly (JSON's standard encoding) instead of
# relying on the host locale, matching how the breed descriptions are read.

# Label lookup tables; the handlers index them with str(class_index).
labels_dogs = json.loads(Path("server/meta/labels_dogs.json").read_text(encoding="utf-8"))
labels_cats = json.loads(Path("server/meta/labels_cats.json").read_text(encoding="utf-8"))

# Image file names, looked up as IMAGES[<"dog" | "cat">][<breed label>].
IMAGES = json.loads(Path("server/meta/images.json").read_text(encoding="utf-8"))
|
||
|
||
def predict_image(image, model, device="cuda", top_k: int = 5) -> list[tuple[int, float]]:
    """Classify one image and return the most probable classes.

    Args:
        image: Image to classify. Objects exposing a PIL-style ``mode`` other
            than ``"RGB"`` are converted to RGB first, because the
            normalization below is configured for exactly three channels.
        model: Callable network; ``model(batch)[0]`` is treated as the vector
            of class scores for the single input image.
        device: Device the input batch is moved to before the forward pass.
            NOTE(review): the default stays ``"cuda"`` for backward
            compatibility, although this file hides all CUDA devices and its
            callers always pass ``"cpu"``.
        top_k: Maximum number of classes to return. It is clamped to the
            number of classes so small models do not make ``torch.topk`` fail.

    Returns:
        ``(class_index, probability)`` pairs ordered from most to least
        probable, at most ``top_k`` of them.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    # RGBA / palette / grayscale uploads would otherwise break the 3-channel
    # Normalize step above.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    # unsqueeze(0) adds the batch dimension the model call expects.
    input_batch = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    k = min(top_k, probabilities.numel())
    topk_probs, predicted_idx = torch.topk(probabilities, k)

    return list(zip(predicted_idx.tolist(), topk_probs.tolist()))
|
||
|
||
breed_dir = Path("server/meta/breed_descriptions")

# One record per "*.txt" description file in breed_dir, ordered by display
# name. The file name without its extension is the breed alias; the display
# name is the alias with underscores replaced by spaces.
DOGS_BEERS = sorted(
    (
        {
            "name": description_path.stem.replace("_", " "),
            "alias": description_path.stem,
            "description": description_path.read_text(encoding="utf-8").strip(),
        }
        for description_path in breed_dir.glob("*.txt")
    ),
    key=lambda breed: breed["name"],
)
|
||
|
||
def _classify_upload(body, model, labels, species, *, humanize_results):
    """Run a classifier on uploaded image bytes and build the JSON response.

    Args:
        body: Raw bytes of the uploaded file.
        model: Classifier passed through to ``predict_image``.
        labels: Mapping from ``str(class_index)`` to the breed label.
        species: ``"dog"`` or ``"cat"``; selects the ``IMAGES`` section and
            the static asset folder used in the image URLs.
        humanize_results: When true, ``results`` values are the display name
            (underscores replaced by spaces); otherwise the raw label.

    Returns:
        ``{"results": {probability: breed}, "images": [{"name", "url"}, ...]}``.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image.
    """
    try:
        img_file = Image.open(io.BytesIO(body))
        # Image.open only parses the header; load() forces the full decode so
        # truncated files are rejected here rather than inside the model code.
        img_file.load()
    except OSError as err:
        # Russian detail text, consistent with the other user-facing messages.
        raise HTTPException(status_code=400, detail="Не удалось прочитать изображение") from err

    # NOTE(review): `results` is keyed by the probability itself, so two
    # classes with exactly equal probability overwrite each other. The shape
    # is kept as-is because clients may depend on it.
    results = {}
    images = []
    for predicted_idx, probability in predict_image(img_file, model, "cpu"):
        predicted_label = labels[str(predicted_idx)]
        name = predicted_label.replace("_", " ")
        images.append({
            "name": name,
            # A breed without an IMAGES entry yields an empty list instead of
            # failing the whole prediction with a KeyError.
            "url": [
                f"/static/assets/{species}/{predicted_label}/{i}"
                for i in IMAGES[species].get(predicted_label, [])
            ],
        })
        results[probability] = name if humanize_results else predicted_label
    return {
        "results": results,
        "images": images,
    }


class BeerdsController(Controller):
    """Breed-recognition API: accepts an uploaded photo, returns the top predictions."""

    path = "/beerds"

    # NOTE(review): inference runs synchronously inside these async handlers
    # and therefore occupies the event loop for its duration.

    @post("/dogs")
    async def beerds_dogs(
        self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
    ) -> dict:
        """Classify an uploaded dog photo; ``results`` values are display names."""
        body = await data.read()
        return _classify_upload(body, DOG_MODEL, labels_dogs, "dog", humanize_results=True)

    @post("/cats")
    async def beerds_cats(
        self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
    ) -> dict:
        """Classify an uploaded cat photo; ``results`` values are raw labels.

        NOTE(review): unlike the dog endpoint, the raw label (with
        underscores) is returned in ``results``. Kept unchanged for response
        compatibility -- confirm whether the difference is intentional.
        """
        body = await data.read()
        return _classify_upload(body, CAT_MODEL, labels_cats, "cat", humanize_results=False)
|
||
|
||
|
||
# Canonical site origin (punycode form), shared by the sitemap and robots.txt.
_SITE_URL = "https://xn-----6kcp3cadbabfh8a0a.xn--p1ai"
# Fixed modification timestamp reported for every sitemap entry.
_SITEMAP_LASTMOD = "2025-10-04T19:01:03+00:00"
# Static pages listed in the sitemap, in output order.
_SITEMAP_STATIC_PATHS = ("/", "/cats", "/donate", "/dogs-characteristics")


def _sitemap_entry(location, lastmod):
    """Render one ``<url>`` element; ``location`` must already be XML-escaped."""
    return (
        "<url>\n"
        f"  <loc>{location}</loc>\n"
        f"  <lastmod>{lastmod}</lastmod>\n"
        "</url>\n"
    )


class BaseController(Controller):
    """HTML pages of the site plus the sitemap.xml and robots.txt endpoints."""

    path = "/"

    @get("/")
    async def dogs(self) -> Template:
        """Render the dog-recognition page."""
        return Template(name="dogs.html")

    @get("/cats")
    async def cats(self) -> Template:
        """Render the cat-recognition page."""
        return Template(name="cats.html")

    @get("/contacts")
    async def contacts(self) -> Template:
        """Render the contacts page."""
        return Template(name="contacts.html")

    @get("/donate")
    async def donate(self) -> Template:
        """Render the donation page."""
        return Template(name="donate.html")

    @get("/dogs-characteristics")
    async def dogs_characteristics(self) -> Template:
        """Render the breed index with every entry of ``DOGS_BEERS``."""
        return Template(name="dogs-characteristics.html", context={"breeds": DOGS_BEERS})

    @get("/dogs-characteristics/{name:str}")
    async def beer_description(self, name: str) -> Template:
        """Render the description page of the breed whose alias equals ``name``.

        Raises:
            HTTPException: 404 when no breed has that alias.
        """
        breed = next((b for b in DOGS_BEERS if b.get("alias") == name), None)
        if breed is None:
            raise HTTPException(status_code=404, detail="Порода не найдена")
        return Template(name="beers-description.html", context={
            "text": markdown.markdown(breed.get("description")),
            "title": breed.get("name"),
            # A breed with a description but no IMAGES entry still renders,
            # just without pictures, instead of failing with a KeyError.
            "images": [f"/static/assets/dog/{name}/{i}" for i in IMAGES['dog'].get(name, [])],
        })

    @get("/sitemap.xml", media_type=MediaType.XML)
    async def sitemaps(self) -> bytes:
        """Return the sitemap: the static pages followed by one URL per breed."""
        from urllib.parse import quote
        from xml.sax.saxutils import escape

        locations = [f"{_SITE_URL}{page}" for page in _SITEMAP_STATIC_PATHS]
        # Aliases come from file names, so percent-encode them for the URL.
        locations += [
            f"{_SITE_URL}/dogs-characteristics/{quote(b['alias'])}" for b in DOGS_BEERS
        ]
        # XML-escape each location so characters such as "&" cannot break the document.
        entries = "".join(
            _sitemap_entry(escape(location), _SITEMAP_LASTMOD) for location in locations
        )
        document = (
            '<?xml version="1.0" encoding="UTF-8"?>\n'
            "<urlset\n"
            '      xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"\n'
            '      xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"\n'
            '      xsi:schemaLocation="http://www.sitemaps.org/schemas/sitemap/0.9\n'
            '            http://www.sitemaps.org/schemas/sitemap/0.9/sitemap.xsd">\n'
            "<!-- created with Free Online Sitemap Generator www.xml-sitemaps.com -->\n"
            "\n"
            f"{entries}"
            "\n"
            "</urlset>\n"
        )
        return document.encode()

    @get("/robots.txt", media_type=MediaType.TEXT)
    async def robots(self) -> str:
        """Return robots.txt: allow everything and point crawlers at the sitemap."""
        return (
            "\n"
            "User-agent: *\n"
            "Allow: /\n"
            "\n"
            f"Sitemap: {_SITE_URL}/sitemap.xml\n"
        )
|
||
|
||
|
||
# ASGI application: registers both controllers, serves server/static under
# /static and renders Jinja templates from server/templates.
# NOTE(review): debug=True is hard-coded -- confirm this is intended for the
# deployed service.
app = Starlite(
    route_handlers=[BeerdsController, BaseController],
    template_config=TemplateConfig(
        engine=JinjaTemplateEngine,
        directory=Path("server/templates"),
    ),
    static_files_config=[
        StaticFilesConfig(path="/static", directories=[Path("server/static")]),
    ],
    debug=True,
)
|