135 lines
3.4 KiB
Python
135 lines
3.4 KiB
Python
from pathlib import Path
|
|
|
|
from PIL import Image
|
|
from starlite import (
|
|
Controller,
|
|
StaticFilesConfig,
|
|
get,
|
|
post,
|
|
Body,
|
|
MediaType,
|
|
RequestEncodingType,
|
|
Starlite,
|
|
UploadFile,
|
|
Template,
|
|
TemplateConfig,
|
|
)
|
|
from starlite.contrib.jinja import JinjaTemplateEngine
|
|
import io
|
|
import os
|
|
import json
|
|
|
|
import torch
|
|
from torchvision import transforms # type: ignore
|
|
import torch.nn.functional as F
|
|
|
|
# Hide all CUDA devices from this process; the model in this module is
# loaded and run on the CPU.
os.environ.update(CUDA_VISIBLE_DEVICES="-1")
|
|
|
|
|
|
def load_model(model_path, device="cpu"):
    """Load a serialized model from disk and switch it to evaluation mode.

    Args:
        model_path: Path of the file previously written with ``torch.save``.
        device: Target passed to ``torch.load`` as ``map_location``.

    Returns:
        The deserialized object after its ``eval()`` method has been called.
    """
    # weights_only=False unpickles arbitrary Python objects, so only
    # trusted checkpoint files must ever be passed in here.
    load_options = {"map_location": device, "weights_only": False}
    net = torch.load(model_path, **load_options)
    net.eval()
    return net
|
|
|
|
|
|
# Lookup table indexed by breed label; the request handler reads
# IMAGES[label] to obtain the picture URL returned to the client.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's locale encoding.
with open("server/meta/images.json", "r", encoding="utf-8") as f:
    IMAGES = json.load(f)
|
|
|
|
|
|
# Classifier instance created once at import time and reused by the handler.
MODEL = load_model(model_path="server/models/dogs_model.pth", device="cpu")
|
|
|
|
# Class-index -> breed-label table. Keys are looked up as str(index) by the
# request handler. Decoded explicitly as UTF-8 so non-ASCII labels do not
# depend on the platform's locale encoding.
with open("server/meta/labels_dogs.json", "r", encoding="utf-8") as f:
    data_labels = f.read()  # raw JSON text, kept as a module-level name
labels_dict = json.loads(data_labels)
|
|
|
|
|
|
def predict_image(image, model, device="cuda"):
    """Classify one image and return the best class with all probabilities.

    The image is resized to 180x180, converted to a tensor, normalized with
    mean 0.5 / std 0.5 per channel, wrapped in a batch of one and fed to
    ``model``. Softmax is applied to the first row of the model output.

    Args:
        image: Picture to classify, normally a PIL image. Images whose
            ``mode`` is a string other than ``"RGB"`` are converted to RGB
            first; any other input is passed to the transforms unchanged.
        model: Callable invoked with the input batch; ``output[0]`` is
            treated as the per-class scores.
        device: Device the input batch is moved to. The default ``"cuda"``
            is kept for backward compatibility; this module hides CUDA
            devices, so the handler in this file passes ``"cpu"``.

    Returns:
        A tuple ``(predicted_idx, probabilities)`` where ``predicted_idx`` is
        the Python int index of the highest probability and ``probabilities``
        is a NumPy array with one entry per class.
    """
    # Uploads are frequently RGBA, palette or grayscale, while the
    # three-channel Normalize below expects exactly three channels.
    mode = getattr(image, "mode", None)
    if isinstance(mode, str) and mode != "RGB":
        image = image.convert("RGB")

    img_size = (180, 180)
    preprocess = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    input_tensor = preprocess(image)
    input_batch = input_tensor.unsqueeze(0).to(device)  # add the batch dimension

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(input_batch)

    probabilities = F.softmax(output[0], dim=0)

    _, predicted_idx = torch.max(probabilities, 0)
    return predicted_idx.item(), probabilities.cpu().numpy()
|
|
|
|
|
|
class BeerdsController(Controller):
    """Route handlers mounted under ``/beerds``."""

    path = "/beerds"

    @post("/dogs")
    async def beerds(
        self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
    ) -> dict:
        """Classify an uploaded picture and describe the predicted label.

        Args:
            data: Multipart file upload containing the image bytes.

        Returns:
            A dict with ``"results"`` mapping the predicted class probability
            to its label, and ``"images"`` holding a single
            ``{"name", "url"}`` entry for that label.

        Raises:
            KeyError: If the predicted index is missing from ``labels_dict``
                or the label is missing from ``IMAGES``.
        """
        body = await data.read()

        img_file = Image.open(io.BytesIO(body))
        predicted_idx, probabilities = predict_image(img_file, MODEL, "cpu")
        predicted_label = labels_dict[str(predicted_idx)]

        images = [{"name": predicted_label, "url": IMAGES[predicted_label]}]
        # Report the probability of the class that was actually predicted;
        # index 0 is only the first class in the output.
        confidence = float(probabilities[predicted_idx])
        return {
            "results": {confidence: predicted_label},
            "images": images,
        }
|
|
|
|
|
|
# Static sitemap document, encoded once because the handler returns bytes.
_SITEMAP_XML = """<?xml version="1.0" encoding="UTF-8"?>
<urlset
xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.sitemaps.org/schemas/sitemap/0.9
http://www.sitemaps.org/schemas/sitemap/0.9/sitemap.xsd">
<!-- created with Free Online Sitemap Generator www.xml-sitemaps.com -->


<url>
<loc>https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/</loc>
<lastmod>2023-05-01T19:01:03+00:00</lastmod>
</url>


</urlset>
""".encode()

# Static crawler rules pointing at the sitemap above.
_ROBOTS_TXT = """
User-agent: *
Allow: /

Sitemap: https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/sitemap.xml
"""


class BaseController(Controller):
    """Site-root handlers: landing page, sitemap and robots file."""

    path = "/"

    @get("/")
    async def main(self) -> Template:
        """Render the ``index.html`` template."""
        return Template(name="index.html")

    @get("/sitemap.xml", media_type=MediaType.XML)
    async def sitemaps(self) -> bytes:
        """Return the static sitemap as encoded XML."""
        return _SITEMAP_XML

    @get("/robots.txt", media_type=MediaType.TEXT)
    async def robots(self) -> str:
        """Return the static robots.txt contents."""
        return _ROBOTS_TXT
|
|
|
|
|
|
# Files in server/static are exposed under the /static URL path.
_static_files = StaticFilesConfig(
    directories=[Path("server/static")],
    path="/static",
)

# Templates are read from server/templates and rendered with Jinja.
_templates = TemplateConfig(
    directory=Path("server/templates"),
    engine=JinjaTemplateEngine,
)

# NOTE(review): debug=True is preserved from the original; confirm it is
# intended for deployments outside development.
app = Starlite(
    route_handlers=[BeerdsController, BaseController],
    static_files_config=[_static_files],
    template_config=_templates,
    debug=True,
)
|