cats + models
This commit is contained in:
parent
61669ea702
commit
54642a15f5
9
Makefile
9
Makefile
|
|
@ -1,11 +1,14 @@
|
||||||
api:
|
api:
|
||||||
uv run granian --interface asgi server.main:app
|
uv run granian --interface asgi server.main:app
|
||||||
|
|
||||||
runml:
|
dog-train:
|
||||||
uv run ml/beerds.py
|
uv run ml/dogs.py
|
||||||
|
|
||||||
|
cat-train:
|
||||||
|
uv run ml/cats.py
|
||||||
|
|
||||||
format:
|
format:
|
||||||
uv run ruff format app
|
uv run ruff format
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
uv run mypy ./ --explicit-package-bases;
|
uv run mypy ./ --explicit-package-bases;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
from PIL import ImageFile
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision.datasets import ImageFolder # type: ignore
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from train import get_labels, load_model, get_loaders, train, show, DEVICE
|
||||||
|
|
||||||
|
print(f"Using device: {DEVICE}")
|
||||||
|
IMG_SIZE = (180, 180)
|
||||||
|
INPUT_DIR = "assets/cat"
|
||||||
|
NUM_EPOCHS = 10
|
||||||
|
MODEL_NAME = "cats_model.pth"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Инициализация данных и модели
|
||||||
|
labels_dict: dict[int, str]
|
||||||
|
dataset: ImageFolder
|
||||||
|
labels_dict, dataset = get_labels(INPUT_DIR, IMG_SIZE)
|
||||||
|
with open("labels_cats.json", "w") as f:
|
||||||
|
f.write(json.dumps(labels_dict))
|
||||||
|
|
||||||
|
model: nn.Module = load_model(MODEL_NAME, labels_dict).to(DEVICE)
|
||||||
|
|
||||||
|
# Подготовка данных
|
||||||
|
train_loader: DataLoader
|
||||||
|
val_loader: DataLoader
|
||||||
|
train_loader, val_loader = get_loaders(dataset)
|
||||||
|
|
||||||
|
# Обучение модели
|
||||||
|
val_acc_history, train_acc_history, val_loss_history, train_loss_history = train(
|
||||||
|
NUM_EPOCHS, MODEL_NAME, model, train_loader, val_loader
|
||||||
|
)
|
||||||
|
|
||||||
|
# Визуализация результатов
|
||||||
|
show(
|
||||||
|
NUM_EPOCHS,
|
||||||
|
val_acc_history,
|
||||||
|
train_acc_history,
|
||||||
|
val_loss_history,
|
||||||
|
train_loss_history,
|
||||||
|
)
|
||||||
167
ml/dogs.py
167
ml/dogs.py
|
|
@ -1,157 +1,28 @@
|
||||||
import os
|
import json
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import torch
|
from PIL import ImageFile
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torchvision.datasets import ImageFolder # type: ignore
|
from torchvision.datasets import ImageFolder # type: ignore
|
||||||
from torch.utils.data import Dataset, DataLoader, random_split
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms # type: ignore
|
|
||||||
import torchvision
|
from train import get_labels, load_model, get_loaders, train, show, DEVICE
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
# Настройка устройства для вычислений
|
|
||||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
print(f"Using device: {DEVICE}")
|
print(f"Using device: {DEVICE}")
|
||||||
IMG_SIZE = (200, 200)
|
IMG_SIZE = (180, 180)
|
||||||
INPUT_DIR = "assets/dog"
|
INPUT_DIR = "assets/dog"
|
||||||
NUM_EPOCHS = 90
|
NUM_EPOCHS = 10
|
||||||
|
MODEL_NAME = "dogs_model.pth"
|
||||||
def get_labels(input_dir, img_size):
|
|
||||||
# Преобразования изображений
|
|
||||||
transform = transforms.Compose([
|
|
||||||
transforms.Resize(img_size),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
|
||||||
])
|
|
||||||
dataset = ImageFolder(root=input_dir, transform=transform)
|
|
||||||
|
|
||||||
# Создание labels_dict для соответствия классов и индексов
|
|
||||||
labels_dict = {idx: class_name for idx, class_name in enumerate(dataset.classes)}
|
|
||||||
return labels_dict, dataset
|
|
||||||
|
|
||||||
def get_loaders(dataset: Dataset) -> Tuple[DataLoader, DataLoader]:
|
|
||||||
# Разделение данных на тренировочные и валидационные
|
|
||||||
train_size = int(0.8 * len(dataset))
|
|
||||||
val_size = len(dataset) - train_size
|
|
||||||
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
|
||||||
|
|
||||||
# Загрузчики данных
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
|
||||||
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
|
|
||||||
return train_loader, val_loader
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path: str, device: str = 'cuda') -> nn.Module:
|
|
||||||
if not os.path.isfile(model_path):
|
|
||||||
print("Start new model")
|
|
||||||
model = torchvision.models.resnet50(pretrained=True)
|
|
||||||
model.fc = torch.nn.Linear(model.fc.in_features, len(labels_dict))
|
|
||||||
return model
|
|
||||||
model = torch.load(model_path, map_location=device, weights_only=False)
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
def train(num_epochs: int, model: nn.Module, train_loader: DataLoader, val_loader: DataLoader) -> Tuple[list[float], list[float], list[float], list[float]]:
|
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
|
||||||
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001, weight_decay=0.001)
|
|
||||||
# История метрик
|
|
||||||
train_loss_history = []
|
|
||||||
train_acc_history = []
|
|
||||||
val_loss_history = []
|
|
||||||
val_acc_history = []
|
|
||||||
# Обучение с проверкой и сохранением лучшей модели
|
|
||||||
best_val_loss = float('inf')
|
|
||||||
for epoch in range(num_epochs):
|
|
||||||
model.train()
|
|
||||||
running_loss = 0.0
|
|
||||||
correct = 0
|
|
||||||
total = 0
|
|
||||||
|
|
||||||
# Обучение на тренировочных данных
|
|
||||||
for inputs, labels in train_loader:
|
|
||||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
outputs = model(inputs)
|
|
||||||
loss = criterion(outputs, labels)
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
running_loss += loss.item()
|
|
||||||
_, predicted = outputs.max(1)
|
|
||||||
total += labels.size(0)
|
|
||||||
correct += predicted.eq(labels).sum().item()
|
|
||||||
|
|
||||||
train_loss = running_loss / len(train_loader)
|
|
||||||
train_acc = 100. * correct / total
|
|
||||||
train_loss_history.append(train_loss)
|
|
||||||
train_acc_history.append(train_acc)
|
|
||||||
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")
|
|
||||||
|
|
||||||
# Оценка на валидационных данных
|
|
||||||
model.eval()
|
|
||||||
val_loss = 0.0
|
|
||||||
correct = 0
|
|
||||||
total = 0
|
|
||||||
with torch.no_grad():
|
|
||||||
for inputs, labels in val_loader:
|
|
||||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
|
||||||
outputs = model(inputs)
|
|
||||||
loss = criterion(outputs, labels)
|
|
||||||
val_loss += loss.item()
|
|
||||||
_, predicted = outputs.max(1)
|
|
||||||
total += labels.size(0)
|
|
||||||
correct += predicted.eq(labels).sum().item()
|
|
||||||
|
|
||||||
val_loss /= len(val_loader)
|
|
||||||
val_acc = 100. * correct / total
|
|
||||||
val_loss_history.append(val_loss)
|
|
||||||
val_acc_history.append(val_acc)
|
|
||||||
if val_loss < best_val_loss:
|
|
||||||
best_val_loss = val_loss
|
|
||||||
print("save model")
|
|
||||||
torch.save(model, "full_model.pth")
|
|
||||||
print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")
|
|
||||||
return val_acc_history, train_acc_history, val_loss_history, train_loss_history
|
|
||||||
|
|
||||||
|
|
||||||
def show(num_epochs: int,
|
|
||||||
val_acc_history: list[float],
|
|
||||||
train_acc_history: list[float],
|
|
||||||
val_loss_history: list[float],
|
|
||||||
train_loss_history: list[float]):
|
|
||||||
|
|
||||||
# Построение графиков
|
|
||||||
epochs = range(1, num_epochs + 1)
|
|
||||||
|
|
||||||
# График точности
|
|
||||||
plt.figure(figsize=(10, 5))
|
|
||||||
plt.plot(epochs, train_acc_history, "bo-", label="Точность на обучении")
|
|
||||||
plt.plot(epochs, val_acc_history, "ro-", label="Точность на валидации")
|
|
||||||
plt.title("Точность на этапах обучения и проверки")
|
|
||||||
plt.xlabel("Эпохи")
|
|
||||||
plt.ylabel("Точность (%)")
|
|
||||||
plt.legend()
|
|
||||||
plt.grid()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
# График потерь
|
|
||||||
plt.figure(figsize=(10, 5))
|
|
||||||
plt.plot(epochs, train_loss_history, "bo-", label="Потери на обучении")
|
|
||||||
plt.plot(epochs, val_loss_history, "ro-", label="Потери на валидации")
|
|
||||||
plt.title("Потери на этапах обучения и проверки")
|
|
||||||
plt.xlabel("Эпохи")
|
|
||||||
plt.ylabel("Потери")
|
|
||||||
plt.legend()
|
|
||||||
plt.grid()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Инициализация данных и модели
|
# Инициализация данных и модели
|
||||||
labels_dict: dict[int, str]
|
labels_dict: dict[int, str]
|
||||||
dataset: ImageFolder
|
dataset: ImageFolder
|
||||||
labels_dict, dataset = get_labels(INPUT_DIR, IMG_SIZE)
|
labels_dict, dataset = get_labels(INPUT_DIR, IMG_SIZE)
|
||||||
|
with open("labels_dogs.json", "w") as f:
|
||||||
|
f.write(json.dumps(labels_dict))
|
||||||
|
|
||||||
model: nn.Module = load_model("full_model.pth").to(DEVICE)
|
model: nn.Module = load_model(MODEL_NAME, labels_dict).to(DEVICE)
|
||||||
|
|
||||||
# Подготовка данных
|
# Подготовка данных
|
||||||
train_loader: DataLoader
|
train_loader: DataLoader
|
||||||
|
|
@ -159,7 +30,15 @@ if __name__ == "__main__":
|
||||||
train_loader, val_loader = get_loaders(dataset)
|
train_loader, val_loader = get_loaders(dataset)
|
||||||
|
|
||||||
# Обучение модели
|
# Обучение модели
|
||||||
val_acc_history, train_acc_history, val_loss_history, train_loss_history = train(NUM_EPOCHS, model, train_loader, val_loader)
|
val_acc_history, train_acc_history, val_loss_history, train_loss_history = train(
|
||||||
|
NUM_EPOCHS, MODEL_NAME, model, train_loader, val_loader
|
||||||
|
)
|
||||||
|
|
||||||
# Визуализация результатов
|
# Визуализация результатов
|
||||||
show(NUM_EPOCHS, val_acc_history, train_acc_history, val_loss_history, train_loss_history)
|
show(
|
||||||
|
NUM_EPOCHS,
|
||||||
|
val_acc_history,
|
||||||
|
train_acc_history,
|
||||||
|
val_loss_history,
|
||||||
|
train_loss_history,
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,6 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torchvision import transforms # type: ignore
|
from torchvision import transforms # type: ignore
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torchvision import transforms
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
@ -13,25 +11,29 @@ with open("labels.json", "r") as f:
|
||||||
labels_dict = json.loads(data_labels)
|
labels_dict = json.loads(data_labels)
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path,device='cuda'):
|
def load_model(model_path, device="cuda"):
|
||||||
model = torch.load(model_path, map_location=device, weights_only=False)
|
model = torch.load(model_path, map_location=device, weights_only=False)
|
||||||
model.eval()
|
model.eval()
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
# Инициализация
|
# Инициализация
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model = load_model('full_model.pth', device=device)
|
model = load_model("full_model.pth", device=device)
|
||||||
|
|
||||||
|
|
||||||
# Преобразования для изображения (адаптируйте под ваш случай)
|
# Преобразования для изображения (адаптируйте под ваш случай)
|
||||||
# Преобразования изображений
|
# Преобразования изображений
|
||||||
def predict_image(image_path, model, device='cuda'):
|
def predict_image(image_path, model, device="cuda"):
|
||||||
img_size = (200, 200)
|
img_size = (180, 180)
|
||||||
preprocess = transforms.Compose([
|
preprocess = transforms.Compose(
|
||||||
|
[
|
||||||
transforms.Resize(img_size),
|
transforms.Resize(img_size),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||||
])
|
]
|
||||||
image = Image.open(image_path).convert('RGB')
|
)
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
input_tensor = preprocess(image)
|
input_tensor = preprocess(image)
|
||||||
input_batch = input_tensor.unsqueeze(0).to(device) # Добавляем dimension для батча
|
input_batch = input_tensor.unsqueeze(0).to(device) # Добавляем dimension для батча
|
||||||
|
|
||||||
|
|
@ -43,10 +45,11 @@ def predict_image(image_path, model, device='cuda'):
|
||||||
_, predicted_idx = torch.max(probabilities, 0)
|
_, predicted_idx = torch.max(probabilities, 0)
|
||||||
return predicted_idx.item(), probabilities.cpu().numpy()
|
return predicted_idx.item(), probabilities.cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
# Пример использования
|
# Пример использования
|
||||||
image_path = 'assets/test/photo_2023-04-25_10-02-25.jpg'
|
image_path = "assets/test/photo_2023-04-25_10-02-25.jpg"
|
||||||
predicted_idx, probabilities = predict_image(image_path, model, device)
|
predicted_idx, probabilities = predict_image(image_path, model, device)
|
||||||
|
|
||||||
# Предполагая, что labels_dict - словарь вида {индекс: 'название_класса'}
|
# Предполагая, что labels_dict - словарь вида {индекс: 'название_класса'}
|
||||||
predicted_label = labels_dict[str(predicted_idx)]
|
predicted_label = labels_dict[str(predicted_idx)]
|
||||||
print(f'Predicted class: {predicted_label} (prob: {probabilities[predicted_idx]:.2f})')
|
print(f"Predicted class: {predicted_label} (prob: {probabilities[predicted_idx]:.2f})")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,156 @@
|
||||||
|
import os
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision.datasets import ImageFolder # type: ignore
|
||||||
|
from torch.utils.data import Dataset, DataLoader, random_split
|
||||||
|
from torchvision import transforms # type: ignore
|
||||||
|
import torchvision
|
||||||
|
from torchvision.models import ResNet50_Weights
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def get_labels(input_dir, img_size):
|
||||||
|
# Преобразования изображений
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize(img_size),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
dataset = ImageFolder(root=input_dir, transform=transform)
|
||||||
|
|
||||||
|
# Создание labels_dict для соответствия классов и индексов
|
||||||
|
labels_dict = {idx: class_name for idx, class_name in enumerate(dataset.classes)}
|
||||||
|
return labels_dict, dataset
|
||||||
|
|
||||||
|
|
||||||
|
def get_loaders(dataset: Dataset) -> Tuple[DataLoader, DataLoader]:
|
||||||
|
# Разделение данных на тренировочные и валидационные
|
||||||
|
train_size = int(0.8 * float(len(dataset))) # type: ignore[arg-type]
|
||||||
|
val_size = len(dataset) - train_size # type: ignore[arg-type]
|
||||||
|
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
||||||
|
|
||||||
|
# Загрузчики данных
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||||
|
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
|
||||||
|
return train_loader, val_loader
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path: str, labels_dict: dict, device: str = "cuda") -> nn.Module:
|
||||||
|
if not os.path.isfile(model_path):
|
||||||
|
print("Start new model")
|
||||||
|
model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)
|
||||||
|
model.fc = torch.nn.Linear(model.fc.in_features, len(labels_dict))
|
||||||
|
return model
|
||||||
|
model = torch.load(model_path, map_location=device, weights_only=False)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def train(
|
||||||
|
num_epochs: int,
|
||||||
|
model_namme: str,
|
||||||
|
model: nn.Module,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
val_loader: DataLoader,
|
||||||
|
) -> Tuple[list[float], list[float], list[float], list[float]]:
|
||||||
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001, weight_decay=0.001) # type: ignore[union-attr]
|
||||||
|
# История метрик
|
||||||
|
train_loss_history = []
|
||||||
|
train_acc_history = []
|
||||||
|
val_loss_history = []
|
||||||
|
val_acc_history = []
|
||||||
|
# Обучение с проверкой и сохранением лучшей модели
|
||||||
|
best_val_loss = float("inf")
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
model.train()
|
||||||
|
running_loss = 0.0
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
# Обучение на тренировочных данных
|
||||||
|
for inputs, labels in train_loader:
|
||||||
|
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(inputs)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item()
|
||||||
|
_, predicted = outputs.max(1)
|
||||||
|
total += labels.size(0)
|
||||||
|
correct += predicted.eq(labels).sum().item()
|
||||||
|
|
||||||
|
train_loss = running_loss / len(train_loader)
|
||||||
|
train_acc = 100.0 * correct / total
|
||||||
|
train_loss_history.append(train_loss)
|
||||||
|
train_acc_history.append(train_acc)
|
||||||
|
print(
|
||||||
|
f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Оценка на валидационных данных
|
||||||
|
model.eval()
|
||||||
|
val_loss = 0.0
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for inputs, labels in val_loader:
|
||||||
|
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||||
|
outputs = model(inputs)
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
val_loss += loss.item()
|
||||||
|
_, predicted = outputs.max(1)
|
||||||
|
total += labels.size(0)
|
||||||
|
correct += predicted.eq(labels).sum().item()
|
||||||
|
|
||||||
|
val_loss /= len(val_loader)
|
||||||
|
val_acc = 100.0 * correct / total
|
||||||
|
val_loss_history.append(val_loss)
|
||||||
|
val_acc_history.append(val_acc)
|
||||||
|
if val_loss < best_val_loss:
|
||||||
|
best_val_loss = val_loss
|
||||||
|
print("save model")
|
||||||
|
torch.save(model, model_namme)
|
||||||
|
print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")
|
||||||
|
return val_acc_history, train_acc_history, val_loss_history, train_loss_history
|
||||||
|
|
||||||
|
|
||||||
|
def show(
|
||||||
|
num_epochs: int,
|
||||||
|
val_acc_history: list[float],
|
||||||
|
train_acc_history: list[float],
|
||||||
|
val_loss_history: list[float],
|
||||||
|
train_loss_history: list[float],
|
||||||
|
):
|
||||||
|
# Построение графиков
|
||||||
|
epochs = range(1, num_epochs + 1)
|
||||||
|
|
||||||
|
# График точности
|
||||||
|
plt.figure(figsize=(10, 5))
|
||||||
|
plt.plot(epochs, train_acc_history, "bo-", label="Точность на обучении")
|
||||||
|
plt.plot(epochs, val_acc_history, "ro-", label="Точность на валидации")
|
||||||
|
plt.title("Точность на этапах обучения и проверки")
|
||||||
|
plt.xlabel("Эпохи")
|
||||||
|
plt.ylabel("Точность (%)")
|
||||||
|
plt.legend()
|
||||||
|
plt.grid()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# График потерь
|
||||||
|
plt.figure(figsize=(10, 5))
|
||||||
|
plt.plot(epochs, train_loss_history, "bo-", label="Потери на обучении")
|
||||||
|
plt.plot(epochs, val_loss_history, "ro-", label="Потери на валидации")
|
||||||
|
plt.title("Потери на этапах обучения и проверки")
|
||||||
|
plt.xlabel("Эпохи")
|
||||||
|
plt.ylabel("Потери")
|
||||||
|
plt.legend()
|
||||||
|
plt.grid()
|
||||||
|
plt.show()
|
||||||
|
|
@ -15,4 +15,5 @@ dependencies = [
|
||||||
"starlite>=1.51.16",
|
"starlite>=1.51.16",
|
||||||
"torch>=2.6.0",
|
"torch>=2.6.0",
|
||||||
"torchvision>=0.21.0",
|
"torchvision>=0.21.0",
|
||||||
|
"types-requests>=2.32.0.20250328",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
125
server/main.py
125
server/main.py
|
|
@ -1,22 +1,30 @@
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from starlite import Starlite, Controller, StaticFilesConfig, get, post, Body, MediaType, RequestEncodingType, Starlite, UploadFile, Template, TemplateConfig
|
from starlite import (
|
||||||
|
Controller,
|
||||||
|
StaticFilesConfig,
|
||||||
|
get,
|
||||||
|
post,
|
||||||
|
Body,
|
||||||
|
MediaType,
|
||||||
|
RequestEncodingType,
|
||||||
|
Starlite,
|
||||||
|
UploadFile,
|
||||||
|
Template,
|
||||||
|
TemplateConfig,
|
||||||
|
)
|
||||||
from starlite.contrib.jinja import JinjaTemplateEngine
|
from starlite.contrib.jinja import JinjaTemplateEngine
|
||||||
import numpy as np
|
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms # type: ignore
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
model_name = "models/beerd_imagenet_02_05_2023.keras"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||||
test_model_imagenet = keras.models.load_model(model_name)
|
|
||||||
|
|
||||||
model_name = "./models/beerd_25_04_2023.keras"
|
|
||||||
test_model = keras.models.load_model(model_name)
|
|
||||||
|
|
||||||
dict_names = {}
|
dict_names = {}
|
||||||
with open("beerds.json", "r") as f:
|
with open("beerds.json", "r") as f:
|
||||||
|
|
@ -25,7 +33,7 @@ for key in dict_names:
|
||||||
dict_names[key] = dict_names[key].replace("_", " ")
|
dict_names[key] = dict_names[key].replace("_", " ")
|
||||||
|
|
||||||
VK_URL = "https://api.vk.com/method/"
|
VK_URL = "https://api.vk.com/method/"
|
||||||
TOKEN = ""
|
TOKEN = os.getenv("VK_TOKEN")
|
||||||
headers = {"Authorization": f"Bearer {TOKEN}"}
|
headers = {"Authorization": f"Bearer {TOKEN}"}
|
||||||
group_id = 220240483
|
group_id = 220240483
|
||||||
postfix = "?v=5.131"
|
postfix = "?v=5.131"
|
||||||
|
|
@ -36,7 +44,8 @@ def get_images():
|
||||||
global IMAGES
|
global IMAGES
|
||||||
|
|
||||||
r = requests.get(
|
r = requests.get(
|
||||||
f"{VK_URL}photos.getAll{postfix}&access_token={TOKEN}&owner_id=-{group_id}&count=200")
|
f"{VK_URL}photos.getAll{postfix}&access_token={TOKEN}&owner_id=-{group_id}&count=200"
|
||||||
|
)
|
||||||
items = r.json().get("response").get("items")
|
items = r.json().get("response").get("items")
|
||||||
for item in items:
|
for item in items:
|
||||||
for s in item.get("sizes"):
|
for s in item.get("sizes"):
|
||||||
|
|
@ -46,63 +55,61 @@ def get_images():
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path, device="cpu"):
|
||||||
|
model = torch.load(model_path, map_location=device, weights_only=False)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
get_images()
|
get_images()
|
||||||
|
MODEL = load_model("./models/dogs_model.pth")
|
||||||
|
|
||||||
|
with open("labels.json", "r") as f:
|
||||||
|
data_labels = f.read()
|
||||||
|
labels_dict = json.loads(data_labels)
|
||||||
|
|
||||||
|
|
||||||
|
def predict_image(image, model, device="cuda"):
|
||||||
|
img_size = (180, 180)
|
||||||
|
preprocess = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize(img_size),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
input_tensor = preprocess(image)
|
||||||
|
input_batch = input_tensor.unsqueeze(0).to(device) # Добавляем dimension для батча
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_batch)
|
||||||
|
|
||||||
|
probabilities = F.softmax(output[0], dim=0)
|
||||||
|
|
||||||
|
_, predicted_idx = torch.max(probabilities, 0)
|
||||||
|
return predicted_idx.item(), probabilities.cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
class BeerdsController(Controller):
|
class BeerdsController(Controller):
|
||||||
path = "/breeds"
|
path = "/breeds"
|
||||||
|
|
||||||
@post("/", media_type=MediaType.TEXT)
|
@post("/", media_type=MediaType.TEXT)
|
||||||
async def beeds(self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)) -> dict:
|
async def beeds(
|
||||||
|
self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)
|
||||||
|
) -> dict:
|
||||||
body = await data.read()
|
body = await data.read()
|
||||||
|
|
||||||
img = Image.open(io.BytesIO(body))
|
img_file = Image.open(io.BytesIO(body))
|
||||||
img = img.convert('RGB')
|
predicted_idx, probabilities = predict_image(img_file, MODEL, "cpu")
|
||||||
|
predicted_label = labels_dict[str(predicted_idx)]
|
||||||
|
|
||||||
img_net = img.resize((180, 180, ), Image.BILINEAR)
|
images = [{"name": predicted_label, "url": IMAGES[predicted_label]}]
|
||||||
img_array = img_to_array(img_net)
|
|
||||||
test_loss_image_net = test_model_imagenet.predict(
|
|
||||||
np.expand_dims(img_array, 0))
|
|
||||||
|
|
||||||
img = img.resize((200, 200, ), Image.BILINEAR)
|
|
||||||
img_array = img_to_array(img)
|
|
||||||
test_loss = test_model.predict(np.expand_dims(img_array, 0))
|
|
||||||
|
|
||||||
result = {}
|
|
||||||
for i, val in enumerate(test_loss[0]):
|
|
||||||
if val <= 0.09:
|
|
||||||
continue
|
|
||||||
result[val] = dict_names[str(i)]
|
|
||||||
|
|
||||||
result_net = {}
|
|
||||||
for i, val in enumerate(test_loss_image_net[0]):
|
|
||||||
if val <= 0.09:
|
|
||||||
continue
|
|
||||||
result_net[val] = dict_names[str(i)]
|
|
||||||
items_one = dict(sorted(result.items(), reverse=True))
|
|
||||||
items_two = dict(sorted(result_net.items(), reverse=True))
|
|
||||||
images = []
|
|
||||||
for item in items_one:
|
|
||||||
name = items_one[item].replace("_", " ")
|
|
||||||
if name not in IMAGES:
|
|
||||||
continue
|
|
||||||
images.append({"name": name, "url": IMAGES[name]})
|
|
||||||
for item in items_two:
|
|
||||||
name = items_two[item].replace("_", " ")
|
|
||||||
if name not in IMAGES:
|
|
||||||
continue
|
|
||||||
images.append({"name": name, "url": IMAGES[name]})
|
|
||||||
return {
|
return {
|
||||||
"results": items_one,
|
"results": {probabilities: predicted_label},
|
||||||
"results_net": items_two,
|
|
||||||
"images": images,
|
"images": images,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class BaseController(Controller):
|
class BaseController(Controller):
|
||||||
path = "/"
|
path = "/"
|
||||||
|
|
||||||
|
|
@ -112,7 +119,7 @@ class BaseController(Controller):
|
||||||
|
|
||||||
@get("/sitemap.xml", media_type=MediaType.XML)
|
@get("/sitemap.xml", media_type=MediaType.XML)
|
||||||
async def sitemaps(self) -> bytes:
|
async def sitemaps(self) -> bytes:
|
||||||
return '''<?xml version="1.0" encoding="UTF-8"?>
|
return """<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<urlset
|
<urlset
|
||||||
xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"
|
xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"
|
||||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
|
@ -128,24 +135,22 @@ class BaseController(Controller):
|
||||||
|
|
||||||
|
|
||||||
</urlset>
|
</urlset>
|
||||||
'''.encode()
|
""".encode()
|
||||||
|
|
||||||
|
|
||||||
@get("/robots.txt", media_type=MediaType.TEXT)
|
@get("/robots.txt", media_type=MediaType.TEXT)
|
||||||
async def robots(self) -> str:
|
async def robots(self) -> str:
|
||||||
return '''
|
return """
|
||||||
User-agent: *
|
User-agent: *
|
||||||
Allow: /
|
Allow: /
|
||||||
|
|
||||||
Sitemap: https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/sitemap.xml
|
Sitemap: https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/sitemap.xml
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
app = Starlite(
|
app = Starlite(
|
||||||
route_handlers=[BeerdsController, BaseController],
|
route_handlers=[BeerdsController, BaseController],
|
||||||
static_files_config=[
|
static_files_config=[
|
||||||
StaticFilesConfig(directories=["static"], path="/static"),
|
StaticFilesConfig(directories=[Path("static")], path="/static"),
|
||||||
|
|
||||||
],
|
],
|
||||||
template_config=TemplateConfig(
|
template_config=TemplateConfig(
|
||||||
directory=Path("templates"),
|
directory=Path("templates"),
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
23
uv.lock
23
uv.lock
|
|
@ -22,6 +22,7 @@ dependencies = [
|
||||||
{ name = "starlite" },
|
{ name = "starlite" },
|
||||||
{ name = "torch" },
|
{ name = "torch" },
|
||||||
{ name = "torchvision" },
|
{ name = "torchvision" },
|
||||||
|
{ name = "types-requests" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
|
|
@ -36,6 +37,7 @@ requires-dist = [
|
||||||
{ name = "starlite", specifier = ">=1.51.16" },
|
{ name = "starlite", specifier = ">=1.51.16" },
|
||||||
{ name = "torch", specifier = ">=2.6.0" },
|
{ name = "torch", specifier = ">=2.6.0" },
|
||||||
{ name = "torchvision", specifier = ">=0.21.0" },
|
{ name = "torchvision", specifier = ">=0.21.0" },
|
||||||
|
{ name = "types-requests", specifier = ">=2.32.0.20250328" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1215,6 +1217,18 @@ wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/c7/30/37a3384d1e2e9320331baca41e835e90a3767303642c7a80d4510152cbcf/triton-3.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5dfa23ba84541d7c0a531dfce76d8bcd19159d50a4a8b14ad01e91734a5c1b0", size = 253154278 },
|
{ url = "https://files.pythonhosted.org/packages/c7/30/37a3384d1e2e9320331baca41e835e90a3767303642c7a80d4510152cbcf/triton-3.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5dfa23ba84541d7c0a531dfce76d8bcd19159d50a4a8b14ad01e91734a5c1b0", size = 253154278 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "types-requests"
|
||||||
|
version = "2.32.0.20250328"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "urllib3" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/00/7d/eb174f74e3f5634eaacb38031bbe467dfe2e545bc255e5c90096ec46bc46/types_requests-2.32.0.20250328.tar.gz", hash = "sha256:c9e67228ea103bd811c96984fac36ed2ae8da87a36a633964a21f199d60baf32", size = 22995 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/cc/15/3700282a9d4ea3b37044264d3e4d1b1f0095a4ebf860a99914fd544e3be3/types_requests-2.32.0.20250328-py3-none-any.whl", hash = "sha256:72ff80f84b15eb3aa7a8e2625fffb6a93f2ad5a0c20215fc1dcfa61117bcb2a2", size = 20663 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typing-extensions"
|
name = "typing-extensions"
|
||||||
version = "4.13.2"
|
version = "4.13.2"
|
||||||
|
|
@ -1232,3 +1246,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be76
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839 },
|
{ url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "urllib3"
|
||||||
|
version = "2.4.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/8a/78/16493d9c386d8e60e442a35feac5e00f0913c0f4b7c217c11e8ec2ff53e0/urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466", size = 390672 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/6b/11/cc635220681e93a0183390e26485430ca2c7b5f9d33b15c74c2861cb8091/urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813", size = 128680 },
|
||||||
|
]
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,8 @@ dir = "../assets/dog"
|
||||||
list_labels = [fname for fname in os.listdir(dir)]
|
list_labels = [fname for fname in os.listdir(dir)]
|
||||||
|
|
||||||
r = requests.get(
|
r = requests.get(
|
||||||
f"{VK_URL}photos.getAll{postfix}&access_token={TOKEN}&owner_id=-{group_id}&count=200")
|
f"{VK_URL}photos.getAll{postfix}&access_token={TOKEN}&owner_id=-{group_id}&count=200"
|
||||||
|
)
|
||||||
if "error" in r.json():
|
if "error" in r.json():
|
||||||
print("error", r.json())
|
print("error", r.json())
|
||||||
exit()
|
exit()
|
||||||
|
|
@ -40,23 +41,36 @@ for name in list_labels:
|
||||||
max_index = i
|
max_index = i
|
||||||
image_name = list_data[max_index]
|
image_name = list_data[max_index]
|
||||||
file_stats = os.stat(os.path.join(dir, name, image_name))
|
file_stats = os.stat(os.path.join(dir, name, image_name))
|
||||||
r = requests.post(f"{VK_URL}photos.createAlbum{postfix}", data={
|
r = requests.post(
|
||||||
"title": name.replace("_", " "), "group_id": group_id}, headers=headers)
|
f"{VK_URL}photos.createAlbum{postfix}",
|
||||||
|
data={"title": name.replace("_", " "), "group_id": group_id},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
if "error" in r.json():
|
if "error" in r.json():
|
||||||
print("error", r.json())
|
print("error", r.json())
|
||||||
break
|
break
|
||||||
album_id = r.json().get("response").get("id")
|
album_id = r.json().get("response").get("id")
|
||||||
r = requests.get(
|
r = requests.get(
|
||||||
f"{VK_URL}photos.getUploadServer{postfix}&album_id={album_id}&access_token={TOKEN}&group_id={group_id}")
|
f"{VK_URL}photos.getUploadServer{postfix}&album_id={album_id}&access_token={TOKEN}&group_id={group_id}"
|
||||||
|
)
|
||||||
url = r.json().get("response").get("upload_url")
|
url = r.json().get("response").get("upload_url")
|
||||||
files = {'file1': open(os.path.join(dir, name, image_name), 'rb')}
|
files = {"file1": open(os.path.join(dir, name, image_name), "rb")}
|
||||||
r = requests.post(url, files=files)
|
r = requests.post(url, files=files)
|
||||||
server = r.json().get("server")
|
server = r.json().get("server")
|
||||||
photos_list = r.json().get("photos_list")
|
photos_list = r.json().get("photos_list")
|
||||||
hash_data = r.json().get("hash")
|
hash_data = r.json().get("hash")
|
||||||
aid = r.json().get("aid")
|
aid = r.json().get("aid")
|
||||||
r = requests.post(f"{VK_URL}photos.save{postfix}&hash={hash_data}", data={"album_id": aid, "server": server,
|
r = requests.post(
|
||||||
"photos_list": photos_list, "caption": name.replace("_", " "), "group_id": group_id}, headers=headers)
|
f"{VK_URL}photos.save{postfix}&hash={hash_data}",
|
||||||
|
data={
|
||||||
|
"album_id": aid,
|
||||||
|
"server": server,
|
||||||
|
"photos_list": photos_list,
|
||||||
|
"caption": name.replace("_", " "),
|
||||||
|
"group_id": group_id,
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
if "error" in r.json():
|
if "error" in r.json():
|
||||||
print("error", r.json())
|
print("error", r.json())
|
||||||
break
|
break
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue