cats + models
This commit is contained in:
parent
61669ea702
commit
54642a15f5
9
Makefile
9
Makefile
|
|
@ -1,11 +1,14 @@
|
|||
api:
|
||||
uv run granian --interface asgi server.main:app
|
||||
|
||||
runml:
|
||||
uv run ml/beerds.py
|
||||
dog-train:
|
||||
uv run ml/dogs.py
|
||||
|
||||
cat-train:
|
||||
uv run ml/cats.py
|
||||
|
||||
format:
|
||||
uv run ruff format app
|
||||
uv run ruff format
|
||||
|
||||
lint:
|
||||
uv run mypy ./ --explicit-package-bases;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
import json
|
||||
|
||||
from PIL import ImageFile
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
import torch.nn as nn
|
||||
from torchvision.datasets import ImageFolder # type: ignore
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from train import get_labels, load_model, get_loaders, train, show, DEVICE
|
||||
|
||||
print(f"Using device: {DEVICE}")
|
||||
IMG_SIZE = (180, 180)
|
||||
INPUT_DIR = "assets/cat"
|
||||
NUM_EPOCHS = 10
|
||||
MODEL_NAME = "cats_model.pth"
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Инициализация данных и модели
|
||||
labels_dict: dict[int, str]
|
||||
dataset: ImageFolder
|
||||
labels_dict, dataset = get_labels(INPUT_DIR, IMG_SIZE)
|
||||
with open("labels_cats.json", "w") as f:
|
||||
f.write(json.dumps(labels_dict))
|
||||
|
||||
model: nn.Module = load_model(MODEL_NAME, labels_dict).to(DEVICE)
|
||||
|
||||
# Подготовка данных
|
||||
train_loader: DataLoader
|
||||
val_loader: DataLoader
|
||||
train_loader, val_loader = get_loaders(dataset)
|
||||
|
||||
# Обучение модели
|
||||
val_acc_history, train_acc_history, val_loss_history, train_loss_history = train(
|
||||
NUM_EPOCHS, MODEL_NAME, model, train_loader, val_loader
|
||||
)
|
||||
|
||||
# Визуализация результатов
|
||||
show(
|
||||
NUM_EPOCHS,
|
||||
val_acc_history,
|
||||
train_acc_history,
|
||||
val_loss_history,
|
||||
train_loss_history,
|
||||
)
|
||||
177
ml/dogs.py
177
ml/dogs.py
|
|
@ -1,165 +1,44 @@
|
|||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import json
|
||||
|
||||
from PIL import ImageFile
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
import torch.nn as nn
|
||||
from torchvision.datasets import ImageFolder # type: ignore
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
from torchvision import transforms # type: ignore
|
||||
import torchvision
|
||||
from typing import Tuple
|
||||
from torchvision.datasets import ImageFolder # type: ignore
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from train import get_labels, load_model, get_loaders, train, show, DEVICE
|
||||
|
||||
# Настройка устройства для вычислений
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"Using device: {DEVICE}")
|
||||
IMG_SIZE = (200, 200)
|
||||
IMG_SIZE = (180, 180)
|
||||
INPUT_DIR = "assets/dog"
|
||||
NUM_EPOCHS = 90
|
||||
|
||||
def get_labels(input_dir, img_size):
|
||||
# Преобразования изображений
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
dataset = ImageFolder(root=input_dir, transform=transform)
|
||||
|
||||
# Создание labels_dict для соответствия классов и индексов
|
||||
labels_dict = {idx: class_name for idx, class_name in enumerate(dataset.classes)}
|
||||
return labels_dict, dataset
|
||||
|
||||
def get_loaders(dataset: Dataset) -> Tuple[DataLoader, DataLoader]:
|
||||
# Разделение данных на тренировочные и валидационные
|
||||
train_size = int(0.8 * len(dataset))
|
||||
val_size = len(dataset) - train_size
|
||||
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
||||
|
||||
# Загрузчики данных
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
|
||||
return train_loader, val_loader
|
||||
|
||||
|
||||
def load_model(model_path: str, device: str = 'cuda') -> nn.Module:
|
||||
if not os.path.isfile(model_path):
|
||||
print("Start new model")
|
||||
model = torchvision.models.resnet50(pretrained=True)
|
||||
model.fc = torch.nn.Linear(model.fc.in_features, len(labels_dict))
|
||||
return model
|
||||
model = torch.load(model_path, map_location=device, weights_only=False)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def train(num_epochs: int, model: nn.Module, train_loader: DataLoader, val_loader: DataLoader) -> Tuple[list[float], list[float], list[float], list[float]]:
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001, weight_decay=0.001)
|
||||
# История метрик
|
||||
train_loss_history = []
|
||||
train_acc_history = []
|
||||
val_loss_history = []
|
||||
val_acc_history = []
|
||||
# Обучение с проверкой и сохранением лучшей модели
|
||||
best_val_loss = float('inf')
|
||||
for epoch in range(num_epochs):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
# Обучение на тренировочных данных
|
||||
for inputs, labels in train_loader:
|
||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||
|
||||
optimizer.zero_grad()
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
_, predicted = outputs.max(1)
|
||||
total += labels.size(0)
|
||||
correct += predicted.eq(labels).sum().item()
|
||||
|
||||
train_loss = running_loss / len(train_loader)
|
||||
train_acc = 100. * correct / total
|
||||
train_loss_history.append(train_loss)
|
||||
train_acc_history.append(train_acc)
|
||||
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")
|
||||
|
||||
# Оценка на валидационных данных
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for inputs, labels in val_loader:
|
||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
val_loss += loss.item()
|
||||
_, predicted = outputs.max(1)
|
||||
total += labels.size(0)
|
||||
correct += predicted.eq(labels).sum().item()
|
||||
|
||||
val_loss /= len(val_loader)
|
||||
val_acc = 100. * correct / total
|
||||
val_loss_history.append(val_loss)
|
||||
val_acc_history.append(val_acc)
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
print("save model")
|
||||
torch.save(model, "full_model.pth")
|
||||
print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")
|
||||
return val_acc_history, train_acc_history, val_loss_history, train_loss_history
|
||||
|
||||
|
||||
def show(num_epochs: int,
|
||||
val_acc_history: list[float],
|
||||
train_acc_history: list[float],
|
||||
val_loss_history: list[float],
|
||||
train_loss_history: list[float]):
|
||||
|
||||
# Построение графиков
|
||||
epochs = range(1, num_epochs + 1)
|
||||
|
||||
# График точности
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.plot(epochs, train_acc_history, "bo-", label="Точность на обучении")
|
||||
plt.plot(epochs, val_acc_history, "ro-", label="Точность на валидации")
|
||||
plt.title("Точность на этапах обучения и проверки")
|
||||
plt.xlabel("Эпохи")
|
||||
plt.ylabel("Точность (%)")
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
plt.show()
|
||||
|
||||
# График потерь
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.plot(epochs, train_loss_history, "bo-", label="Потери на обучении")
|
||||
plt.plot(epochs, val_loss_history, "ro-", label="Потери на валидации")
|
||||
plt.title("Потери на этапах обучения и проверки")
|
||||
plt.xlabel("Эпохи")
|
||||
plt.ylabel("Потери")
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
plt.show()
|
||||
NUM_EPOCHS = 10
|
||||
MODEL_NAME = "dogs_model.pth"
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Инициализация данных и модели
|
||||
labels_dict: dict[int, str]
|
||||
dataset: ImageFolder
|
||||
labels_dict, dataset = get_labels(INPUT_DIR, IMG_SIZE)
|
||||
|
||||
model: nn.Module = load_model("full_model.pth").to(DEVICE)
|
||||
|
||||
with open("labels_dogs.json", "w") as f:
|
||||
f.write(json.dumps(labels_dict))
|
||||
|
||||
model: nn.Module = load_model(MODEL_NAME, labels_dict).to(DEVICE)
|
||||
|
||||
# Подготовка данных
|
||||
train_loader: DataLoader
|
||||
val_loader: DataLoader
|
||||
train_loader, val_loader = get_loaders(dataset)
|
||||
|
||||
|
||||
# Обучение модели
|
||||
val_acc_history, train_acc_history, val_loss_history, train_loss_history = train(NUM_EPOCHS, model, train_loader, val_loader)
|
||||
|
||||
val_acc_history, train_acc_history, val_loss_history, train_loss_history = train(
|
||||
NUM_EPOCHS, MODEL_NAME, model, train_loader, val_loader
|
||||
)
|
||||
|
||||
# Визуализация результатов
|
||||
show(NUM_EPOCHS, val_acc_history, train_acc_history, val_loss_history, train_loss_history)
|
||||
show(
|
||||
NUM_EPOCHS,
|
||||
val_acc_history,
|
||||
train_acc_history,
|
||||
val_loss_history,
|
||||
train_loss_history,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,52 +1,55 @@
|
|||
|
||||
import torch
|
||||
from torchvision import transforms # type: ignore
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import json
|
||||
|
||||
|
||||
# Создание labels_dict для соответствия классов и индексов
|
||||
with open("labels.json", "r") as f:
|
||||
with open("labels.json", "r") as f:
|
||||
data_labels = f.read()
|
||||
labels_dict = json.loads(data_labels)
|
||||
|
||||
|
||||
def load_model(model_path,device='cuda'):
|
||||
def load_model(model_path, device="cuda"):
|
||||
model = torch.load(model_path, map_location=device, weights_only=False)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
# Инициализация
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model = load_model('full_model.pth', device=device)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = load_model("full_model.pth", device=device)
|
||||
|
||||
|
||||
# Преобразования для изображения (адаптируйте под ваш случай)
|
||||
# Преобразования изображений
|
||||
def predict_image(image_path, model, device='cuda'):
|
||||
img_size = (200, 200)
|
||||
preprocess = transforms.Compose([
|
||||
transforms.Resize(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
def predict_image(image_path, model, device="cuda"):
|
||||
img_size = (180, 180)
|
||||
preprocess = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
)
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
input_tensor = preprocess(image)
|
||||
input_batch = input_tensor.unsqueeze(0).to(device) # Добавляем dimension для батча
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_batch)
|
||||
|
||||
|
||||
probabilities = F.softmax(output[0], dim=0)
|
||||
|
||||
|
||||
_, predicted_idx = torch.max(probabilities, 0)
|
||||
return predicted_idx.item(), probabilities.cpu().numpy()
|
||||
|
||||
|
||||
# Пример использования
|
||||
image_path = 'assets/test/photo_2023-04-25_10-02-25.jpg'
|
||||
image_path = "assets/test/photo_2023-04-25_10-02-25.jpg"
|
||||
predicted_idx, probabilities = predict_image(image_path, model, device)
|
||||
|
||||
# Предполагая, что labels_dict - словарь вида {индекс: 'название_класса'}
|
||||
predicted_label = labels_dict[str(predicted_idx)]
|
||||
print(f'Predicted class: {predicted_label} (prob: {probabilities[predicted_idx]:.2f})')
|
||||
print(f"Predicted class: {predicted_label} (prob: {probabilities[predicted_idx]:.2f})")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,156 @@
|
|||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.datasets import ImageFolder # type: ignore
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
from torchvision import transforms # type: ignore
|
||||
import torchvision
|
||||
from torchvision.models import ResNet50_Weights
|
||||
from typing import Tuple
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def get_labels(input_dir, img_size):
|
||||
# Преобразования изображений
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
)
|
||||
dataset = ImageFolder(root=input_dir, transform=transform)
|
||||
|
||||
# Создание labels_dict для соответствия классов и индексов
|
||||
labels_dict = {idx: class_name for idx, class_name in enumerate(dataset.classes)}
|
||||
return labels_dict, dataset
|
||||
|
||||
|
||||
def get_loaders(dataset: Dataset) -> Tuple[DataLoader, DataLoader]:
|
||||
# Разделение данных на тренировочные и валидационные
|
||||
train_size = int(0.8 * float(len(dataset))) # type: ignore[arg-type]
|
||||
val_size = len(dataset) - train_size # type: ignore[arg-type]
|
||||
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
||||
|
||||
# Загрузчики данных
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
|
||||
return train_loader, val_loader
|
||||
|
||||
|
||||
def load_model(model_path: str, labels_dict: dict, device: str = "cuda") -> nn.Module:
|
||||
if not os.path.isfile(model_path):
|
||||
print("Start new model")
|
||||
model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)
|
||||
model.fc = torch.nn.Linear(model.fc.in_features, len(labels_dict))
|
||||
return model
|
||||
model = torch.load(model_path, map_location=device, weights_only=False)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def train(
|
||||
num_epochs: int,
|
||||
model_namme: str,
|
||||
model: nn.Module,
|
||||
train_loader: DataLoader,
|
||||
val_loader: DataLoader,
|
||||
) -> Tuple[list[float], list[float], list[float], list[float]]:
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001, weight_decay=0.001) # type: ignore[union-attr]
|
||||
# История метрик
|
||||
train_loss_history = []
|
||||
train_acc_history = []
|
||||
val_loss_history = []
|
||||
val_acc_history = []
|
||||
# Обучение с проверкой и сохранением лучшей модели
|
||||
best_val_loss = float("inf")
|
||||
for epoch in range(num_epochs):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
# Обучение на тренировочных данных
|
||||
for inputs, labels in train_loader:
|
||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||
|
||||
optimizer.zero_grad()
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
_, predicted = outputs.max(1)
|
||||
total += labels.size(0)
|
||||
correct += predicted.eq(labels).sum().item()
|
||||
|
||||
train_loss = running_loss / len(train_loader)
|
||||
train_acc = 100.0 * correct / total
|
||||
train_loss_history.append(train_loss)
|
||||
train_acc_history.append(train_acc)
|
||||
print(
|
||||
f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%"
|
||||
)
|
||||
|
||||
# Оценка на валидационных данных
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for inputs, labels in val_loader:
|
||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
val_loss += loss.item()
|
||||
_, predicted = outputs.max(1)
|
||||
total += labels.size(0)
|
||||
correct += predicted.eq(labels).sum().item()
|
||||
|
||||
val_loss /= len(val_loader)
|
||||
val_acc = 100.0 * correct / total
|
||||
val_loss_history.append(val_loss)
|
||||
val_acc_history.append(val_acc)
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
print("save model")
|
||||
torch.save(model, model_namme)
|
||||
print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")
|
||||
return val_acc_history, train_acc_history, val_loss_history, train_loss_history
|
||||
|
||||
|
||||
def show(
|
||||
num_epochs: int,
|
||||
val_acc_history: list[float],
|
||||
train_acc_history: list[float],
|
||||
val_loss_history: list[float],
|
||||
train_loss_history: list[float],
|
||||
):
|
||||
# Построение графиков
|
||||
epochs = range(1, num_epochs + 1)
|
||||
|
||||
# График точности
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.plot(epochs, train_acc_history, "bo-", label="Точность на обучении")
|
||||
plt.plot(epochs, val_acc_history, "ro-", label="Точность на валидации")
|
||||
plt.title("Точность на этапах обучения и проверки")
|
||||
plt.xlabel("Эпохи")
|
||||
plt.ylabel("Точность (%)")
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
plt.show()
|
||||
|
||||
# График потерь
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.plot(epochs, train_loss_history, "bo-", label="Потери на обучении")
|
||||
plt.plot(epochs, val_loss_history, "ro-", label="Потери на валидации")
|
||||
plt.title("Потери на этапах обучения и проверки")
|
||||
plt.xlabel("Эпохи")
|
||||
plt.ylabel("Потери")
|
||||
plt.legend()
|
||||
plt.grid()
|
||||
plt.show()
|
||||
|
|
@ -15,4 +15,5 @@ dependencies = [
|
|||
"starlite>=1.51.16",
|
||||
"torch>=2.6.0",
|
||||
"torchvision>=0.21.0",
|
||||
"types-requests>=2.32.0.20250328",
|
||||
]
|
||||
|
|
|
|||
125
server/main.py
125
server/main.py
|
|
@ -1,22 +1,30 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
from starlite import Starlite, Controller, StaticFilesConfig, get, post, Body, MediaType, RequestEncodingType, Starlite, UploadFile, Template, TemplateConfig
|
||||
from starlite import (
|
||||
Controller,
|
||||
StaticFilesConfig,
|
||||
get,
|
||||
post,
|
||||
Body,
|
||||
MediaType,
|
||||
RequestEncodingType,
|
||||
Starlite,
|
||||
UploadFile,
|
||||
Template,
|
||||
TemplateConfig,
|
||||
)
|
||||
from starlite.contrib.jinja import JinjaTemplateEngine
|
||||
import numpy as np
|
||||
import io
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||
|
||||
import torch
|
||||
from torchvision import transforms # type: ignore
|
||||
import torch.nn.functional as F
|
||||
|
||||
model_name = "models/beerd_imagenet_02_05_2023.keras"
|
||||
test_model_imagenet = keras.models.load_model(model_name)
|
||||
|
||||
model_name = "./models/beerd_25_04_2023.keras"
|
||||
test_model = keras.models.load_model(model_name)
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
|
||||
dict_names = {}
|
||||
with open("beerds.json", "r") as f:
|
||||
|
|
@ -25,7 +33,7 @@ for key in dict_names:
|
|||
dict_names[key] = dict_names[key].replace("_", " ")
|
||||
|
||||
VK_URL = "https://api.vk.com/method/"
|
||||
TOKEN = ""
|
||||
TOKEN = os.getenv("VK_TOKEN")
|
||||
headers = {"Authorization": f"Bearer {TOKEN}"}
|
||||
group_id = 220240483
|
||||
postfix = "?v=5.131"
|
||||
|
|
@ -36,7 +44,8 @@ def get_images():
|
|||
global IMAGES
|
||||
|
||||
r = requests.get(
|
||||
f"{VK_URL}photos.getAll{postfix}&access_token={TOKEN}&owner_id=-{group_id}&count=200")
|
||||
f"{VK_URL}photos.getAll{postfix}&access_token={TOKEN}&owner_id=-{group_id}&count=200"
|
||||
)
|
||||
items = r.json().get("response").get("items")
|
||||
for item in items:
|
||||
for s in item.get("sizes"):
|
||||
|
|
@ -46,61 +55,59 @@ def get_images():
|
|||
break
|
||||
|
||||
|
||||
def load_model(model_path, device="cpu"):
|
||||
model = torch.load(model_path, map_location=device, weights_only=False)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
get_images()
|
||||
MODEL = load_model("./models/dogs_model.pth")
|
||||
|
||||
with open("labels.json", "r") as f:
|
||||
data_labels = f.read()
|
||||
labels_dict = json.loads(data_labels)
|
||||
|
||||
|
||||
def predict_image(image, model, device="cuda"):
|
||||
img_size = (180, 180)
|
||||
preprocess = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(img_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
]
|
||||
)
|
||||
input_tensor = preprocess(image)
|
||||
input_batch = input_tensor.unsqueeze(0).to(device) # Добавляем dimension для батча
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_batch)
|
||||
|
||||
probabilities = F.softmax(output[0], dim=0)
|
||||
|
||||
_, predicted_idx = torch.max(probabilities, 0)
|
||||
return predicted_idx.item(), probabilities.cpu().numpy()
|
||||
|
||||
|
||||
class BeerdsController(Controller):
|
||||
path = "/breeds"
|
||||
|
||||
@post("/", media_type=MediaType.TEXT)
|
||||
async def beeds(self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)) -> dict:
|
||||
async def beeds(
|
||||
self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
|
||||
) -> dict:
|
||||
body = await data.read()
|
||||
|
||||
img = Image.open(io.BytesIO(body))
|
||||
img = img.convert('RGB')
|
||||
img_file = Image.open(io.BytesIO(body))
|
||||
predicted_idx, probabilities = predict_image(img_file, MODEL, "cpu")
|
||||
predicted_label = labels_dict[str(predicted_idx)]
|
||||
|
||||
img_net = img.resize((180, 180, ), Image.BILINEAR)
|
||||
img_array = img_to_array(img_net)
|
||||
test_loss_image_net = test_model_imagenet.predict(
|
||||
np.expand_dims(img_array, 0))
|
||||
|
||||
img = img.resize((200, 200, ), Image.BILINEAR)
|
||||
img_array = img_to_array(img)
|
||||
test_loss = test_model.predict(np.expand_dims(img_array, 0))
|
||||
|
||||
result = {}
|
||||
for i, val in enumerate(test_loss[0]):
|
||||
if val <= 0.09:
|
||||
continue
|
||||
result[val] = dict_names[str(i)]
|
||||
|
||||
result_net = {}
|
||||
for i, val in enumerate(test_loss_image_net[0]):
|
||||
if val <= 0.09:
|
||||
continue
|
||||
result_net[val] = dict_names[str(i)]
|
||||
items_one = dict(sorted(result.items(), reverse=True))
|
||||
items_two = dict(sorted(result_net.items(), reverse=True))
|
||||
images = []
|
||||
for item in items_one:
|
||||
name = items_one[item].replace("_", " ")
|
||||
if name not in IMAGES:
|
||||
continue
|
||||
images.append({"name": name, "url": IMAGES[name]})
|
||||
for item in items_two:
|
||||
name = items_two[item].replace("_", " ")
|
||||
if name not in IMAGES:
|
||||
continue
|
||||
images.append({"name": name, "url": IMAGES[name]})
|
||||
images = [{"name": predicted_label, "url": IMAGES[predicted_label]}]
|
||||
return {
|
||||
"results": items_one,
|
||||
"results_net": items_two,
|
||||
"results": {probabilities: predicted_label},
|
||||
"images": images,
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class BaseController(Controller):
|
||||
|
|
@ -112,7 +119,7 @@ class BaseController(Controller):
|
|||
|
||||
@get("/sitemap.xml", media_type=MediaType.XML)
|
||||
async def sitemaps(self) -> bytes:
|
||||
return '''<?xml version="1.0" encoding="UTF-8"?>
|
||||
return """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<urlset
|
||||
xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
|
|
@ -128,24 +135,22 @@ class BaseController(Controller):
|
|||
|
||||
|
||||
</urlset>
|
||||
'''.encode()
|
||||
|
||||
""".encode()
|
||||
|
||||
@get("/robots.txt", media_type=MediaType.TEXT)
|
||||
async def robots(self) -> str:
|
||||
return '''
|
||||
return """
|
||||
User-agent: *
|
||||
Allow: /
|
||||
|
||||
Sitemap: https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/sitemap.xml
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
app = Starlite(
|
||||
route_handlers=[BeerdsController, BaseController],
|
||||
static_files_config=[
|
||||
StaticFilesConfig(directories=["static"], path="/static"),
|
||||
|
||||
StaticFilesConfig(directories=[Path("static")], path="/static"),
|
||||
],
|
||||
template_config=TemplateConfig(
|
||||
directory=Path("templates"),
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
23
uv.lock
23
uv.lock
|
|
@ -22,6 +22,7 @@ dependencies = [
|
|||
{ name = "starlite" },
|
||||
{ name = "torch" },
|
||||
{ name = "torchvision" },
|
||||
{ name = "types-requests" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
|
|
@ -36,6 +37,7 @@ requires-dist = [
|
|||
{ name = "starlite", specifier = ">=1.51.16" },
|
||||
{ name = "torch", specifier = ">=2.6.0" },
|
||||
{ name = "torchvision", specifier = ">=0.21.0" },
|
||||
{ name = "types-requests", specifier = ">=2.32.0.20250328" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1215,6 +1217,18 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/c7/30/37a3384d1e2e9320331baca41e835e90a3767303642c7a80d4510152cbcf/triton-3.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5dfa23ba84541d7c0a531dfce76d8bcd19159d50a4a8b14ad01e91734a5c1b0", size = 253154278 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-requests"
|
||||
version = "2.32.0.20250328"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/00/7d/eb174f74e3f5634eaacb38031bbe467dfe2e545bc255e5c90096ec46bc46/types_requests-2.32.0.20250328.tar.gz", hash = "sha256:c9e67228ea103bd811c96984fac36ed2ae8da87a36a633964a21f199d60baf32", size = 22995 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/15/3700282a9d4ea3b37044264d3e4d1b1f0095a4ebf860a99914fd544e3be3/types_requests-2.32.0.20250328-py3-none-any.whl", hash = "sha256:72ff80f84b15eb3aa7a8e2625fffb6a93f2ad5a0c20215fc1dcfa61117bcb2a2", size = 20663 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.13.2"
|
||||
|
|
@ -1232,3 +1246,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be76
|
|||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/8a/78/16493d9c386d8e60e442a35feac5e00f0913c0f4b7c217c11e8ec2ff53e0/urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466", size = 390672 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6b/11/cc635220681e93a0183390e26485430ca2c7b5f9d33b15c74c2861cb8091/urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813", size = 128680 },
|
||||
]
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ dir = "../assets/dog"
|
|||
list_labels = [fname for fname in os.listdir(dir)]
|
||||
|
||||
r = requests.get(
|
||||
f"{VK_URL}photos.getAll{postfix}&access_token={TOKEN}&owner_id=-{group_id}&count=200")
|
||||
f"{VK_URL}photos.getAll{postfix}&access_token={TOKEN}&owner_id=-{group_id}&count=200"
|
||||
)
|
||||
if "error" in r.json():
|
||||
print("error", r.json())
|
||||
exit()
|
||||
|
|
@ -40,23 +41,36 @@ for name in list_labels:
|
|||
max_index = i
|
||||
image_name = list_data[max_index]
|
||||
file_stats = os.stat(os.path.join(dir, name, image_name))
|
||||
r = requests.post(f"{VK_URL}photos.createAlbum{postfix}", data={
|
||||
"title": name.replace("_", " "), "group_id": group_id}, headers=headers)
|
||||
r = requests.post(
|
||||
f"{VK_URL}photos.createAlbum{postfix}",
|
||||
data={"title": name.replace("_", " "), "group_id": group_id},
|
||||
headers=headers,
|
||||
)
|
||||
if "error" in r.json():
|
||||
print("error", r.json())
|
||||
break
|
||||
album_id = r.json().get("response").get("id")
|
||||
r = requests.get(
|
||||
f"{VK_URL}photos.getUploadServer{postfix}&album_id={album_id}&access_token={TOKEN}&group_id={group_id}")
|
||||
f"{VK_URL}photos.getUploadServer{postfix}&album_id={album_id}&access_token={TOKEN}&group_id={group_id}"
|
||||
)
|
||||
url = r.json().get("response").get("upload_url")
|
||||
files = {'file1': open(os.path.join(dir, name, image_name), 'rb')}
|
||||
files = {"file1": open(os.path.join(dir, name, image_name), "rb")}
|
||||
r = requests.post(url, files=files)
|
||||
server = r.json().get("server")
|
||||
photos_list = r.json().get("photos_list")
|
||||
hash_data = r.json().get("hash")
|
||||
aid = r.json().get("aid")
|
||||
r = requests.post(f"{VK_URL}photos.save{postfix}&hash={hash_data}", data={"album_id": aid, "server": server,
|
||||
"photos_list": photos_list, "caption": name.replace("_", " "), "group_id": group_id}, headers=headers)
|
||||
r = requests.post(
|
||||
f"{VK_URL}photos.save{postfix}&hash={hash_data}",
|
||||
data={
|
||||
"album_id": aid,
|
||||
"server": server,
|
||||
"photos_list": photos_list,
|
||||
"caption": name.replace("_", " "),
|
||||
"group_id": group_id,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
if "error" in r.json():
|
||||
print("error", r.json())
|
||||
break
|
||||
|
|
|
|||
Loading…
Reference in New Issue