160 lines
4.0 KiB
Python
160 lines
4.0 KiB
Python
from pathlib import Path
|
|
|
|
from PIL import Image
|
|
from starlite import (
|
|
Controller,
|
|
StaticFilesConfig,
|
|
get,
|
|
post,
|
|
Body,
|
|
MediaType,
|
|
RequestEncodingType,
|
|
Starlite,
|
|
UploadFile,
|
|
Template,
|
|
TemplateConfig,
|
|
)
|
|
from starlite.contrib.jinja import JinjaTemplateEngine
|
|
import io
|
|
import os
|
|
import json
|
|
import requests
|
|
|
|
import torch
|
|
from torchvision import transforms # type: ignore
|
|
import torch.nn.functional as F
|
|
|
|
# Hide every CUDA device from this process so PyTorch runs on the CPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})
|
|
|
|
# Display names loaded from beerds.json, with underscores replaced by spaces.
# The file is decoded as UTF-8 explicitly so the result does not depend on the
# platform's locale encoding (e.g. cp1252 on Windows).
with open("beerds.json", "r", encoding="utf-8") as f:
    dict_names = {
        key: name.replace("_", " ") for key, name in json.load(f).items()
    }
|
|
|
|
# --- VK API access ---------------------------------------------------------
# Base endpoint for VK API calls; the method name is appended directly.
VK_URL = "https://api.vk.com/method/"
# API-version query string placed right after the method name.
postfix = "?v=5.131"
# Numeric id of the VK community whose photos are read (sent as a negative
# owner_id by get_images).
group_id = 220240483

# Access token taken from the environment; None when VK_TOKEN is unset.
TOKEN = os.environ.get("VK_TOKEN")
# Bearer-auth header derived from TOKEN (renders "Bearer None" if the token is
# unset). NOTE(review): appears unused in this file; get_images passes the
# token as a query parameter instead.
headers = {"Authorization": "Bearer {}".format(TOKEN)}

# Photo caption -> photo URL, populated by get_images().
IMAGES = {}
|
|
|
|
|
|
def get_images(timeout=10):
    """Populate IMAGES with photo URLs from the VK community's photos.

    Calls VK's ``photos.getAll`` for owner ``-group_id`` (up to 200 photos).
    For each photo it stores the URL of its ``"x"``-type size in the
    module-level IMAGES dict, keyed by the photo's caption (``text``).
    Photos without an ``"x"`` size are skipped.

    Args:
        timeout: Seconds to wait for the VK API. Without a timeout a stalled
            connection would block application startup indefinitely.

    Raises:
        requests.HTTPError: If VK answers with a non-2xx HTTP status.
        RuntimeError: If the JSON payload has no ``response`` object. The VK
            API reports failures such as a bad token in an ``error`` object,
            which is included in the message.
    """
    # Let requests encode the query parameters. They are appended after the
    # "?v=..." version string that postfix already puts in the URL.
    r = requests.get(
        f"{VK_URL}photos.getAll{postfix}",
        params={
            "access_token": TOKEN,
            "owner_id": f"-{group_id}",
            "count": 200,
        },
        timeout=timeout,
    )
    r.raise_for_status()

    payload = r.json()
    response = payload.get("response")
    if response is None:
        raise RuntimeError(f"VK photos.getAll failed: {payload.get('error')}")

    for item in response.get("items") or []:
        for size in item.get("sizes") or []:
            if size.get("type") == "x":
                IMAGES[item.get("text")] = size.get("url")
                break
|
|
|
|
|
|
def load_model(model_path, device="cpu"):
    """Load a pickled PyTorch model and put it into inference mode.

    Args:
        model_path: File produced by ``torch.save`` that holds the model
            object itself (it is called and ``eval()``-ed, not treated as a
            state dict).
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The deserialised model, already switched to eval mode.

    Note:
        ``weights_only=False`` un-pickles arbitrary objects, which can execute
        code. Only load trusted model files with this function.
    """
    net = torch.load(
        model_path,
        weights_only=False,
        map_location=device,
    )
    net.eval()
    return net
|
|
|
|
|
|
# Startup side effects: fetch the breed photo URLs from VK and load the
# classifier once, so every request reuses them.
get_images()
MODEL = load_model("./models/dogs_model.pth")

# Class-index -> label mapping, looked up with str(class_index). The file is
# decoded as UTF-8 explicitly so that any non-ASCII labels do not depend on
# the platform locale.
with open("labels.json", "r", encoding="utf-8") as f:
    data_labels = f.read()
labels_dict = json.loads(data_labels)
|
|
|
|
|
|
def predict_image(image, model, device="cuda"):
    """Classify one image and return the top class plus all probabilities.

    The image is resized to 180x180 and converted to a tensor. Each channel is
    normalised with mean 0.5 and std 0.5, mapping [0, 1] to [-1, 1]. The
    result is fed to ``model`` as a batch of one.

    Args:
        image: A PIL image, or anything ``transforms.ToTensor`` accepts.
            PIL images that are not in RGB mode (RGBA, greyscale, palette,
            CMYK) are converted to RGB first. The 3-channel normalisation
            fails on any other channel count.
        model: Callable returning logits. It is presumably shaped
            (batch, classes), given the ``output[0]`` indexing; TODO confirm.
        device: Device the input batch is moved to; must match the model's
            device. This module hides all GPUs, so callers here pass "cpu".

    Returns:
        ``(predicted_idx, probabilities)``: the argmax class index as a Python
        int, and the softmax probabilities as a NumPy array.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    img_size = (180, 180)
    preprocess = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    input_batch = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_batch)

    probabilities = F.softmax(output[0], dim=0)
    predicted_idx = torch.argmax(probabilities).item()
    return predicted_idx, probabilities.cpu().numpy()
|
|
|
|
|
|
class BeerdsController(Controller):
    """Breed-recognition API mounted at ``/breeds``."""

    path = "/breeds"

    @post("/", media_type=MediaType.JSON)
    async def beeds(
        self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
    ) -> dict:
        """Classify an uploaded photo and return the predicted breed.

        Args:
            data: The image file, sent as multipart form data.

        Returns:
            ``{"results": {"<probability>": <label>}, "images": [...]}``.

            - ``results`` has one entry that maps the top-1 probability to the
              predicted label. The key is a string because JSON object keys
              must be strings.
            - ``images`` holds ``{"name", "url"}`` for that label from IMAGES,
              or is empty when no photo is known for it.
        """
        body = await data.read()

        # Force 3-channel RGB so RGBA/greyscale uploads do not break the
        # 3-channel normalisation used by the model's preprocessing.
        img_file = Image.open(io.BytesIO(body)).convert("RGB")
        # NOTE(review): inference is synchronous and blocks the event loop
        # while it runs; consider offloading it to a worker thread.
        predicted_idx, probabilities = predict_image(img_file, MODEL, "cpu")
        predicted_label = labels_dict[str(predicted_idx)]

        url = IMAGES.get(predicted_label)
        images = (
            [{"name": predicted_label, "url": url}] if url is not None else []
        )

        # Key by the predicted class's probability. A NumPy array is not
        # hashable and cannot be a dict key.
        top_probability = str(float(probabilities[predicted_idx]))
        return {
            "results": {top_probability: predicted_label},
            "images": images,
        }
|
|
|
|
|
|
class BaseController(Controller):
    """Site-level pages: the landing page, sitemap.xml and robots.txt."""

    path = "/"

    @get("/")
    async def main(self) -> Template:
        """Render the landing page template."""
        return Template(name="index.html")

    @get("/sitemap.xml", media_type=MediaType.XML)
    async def sitemaps(self) -> bytes:
        """Serve a static sitemap that lists the site's root URL."""
        xml_lines = [
            '<?xml version="1.0" encoding="UTF-8"?>',
            "<urlset",
            'xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"',
            'xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"',
            'xsi:schemaLocation="http://www.sitemaps.org/schemas/sitemap/0.9',
            'http://www.sitemaps.org/schemas/sitemap/0.9/sitemap.xsd">',
            "<!-- created with Free Online Sitemap Generator www.xml-sitemaps.com -->",
            "",
            "",
            "<url>",
            "<loc>https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/</loc>",
            "<lastmod>2023-05-01T19:01:03+00:00</lastmod>",
            "</url>",
            "",
            "",
            "</urlset>",
            "",
        ]
        return "\n".join(xml_lines).encode()

    @get("/robots.txt", media_type=MediaType.TEXT)
    async def robots(self) -> str:
        """Serve robots.txt, which allows all crawlers and points to the sitemap."""
        robots_lines = [
            "",
            "User-agent: *",
            "Allow: /",
            "",
            "Sitemap: https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/sitemap.xml",
            "",
        ]
        return "\n".join(robots_lines)
|
|
|
|
|
|
# ASGI application: the breed-recognition and site-level controllers, static
# assets from ./static served under /static, and Jinja templates from
# ./templates.
app = Starlite(
    route_handlers=[BeerdsController, BaseController],
    template_config=TemplateConfig(
        engine=JinjaTemplateEngine,
        directory=Path("templates"),
    ),
    static_files_config=[
        StaticFilesConfig(path="/static", directories=[Path("static")]),
    ],
)
|