157 lines
5.8 KiB
Python
157 lines
5.8 KiB
Python
import os
|
||
import matplotlib.pyplot as plt
|
||
import torch
|
||
import torch.nn as nn
|
||
from torchvision.datasets import ImageFolder # type: ignore
|
||
from torch.utils.data import Dataset, DataLoader, random_split
|
||
from torchvision import transforms # type: ignore
|
||
import torchvision
|
||
from torchvision.models import ResNet50_Weights
|
||
from typing import Tuple
|
||
|
||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
|
||
|
||
def get_labels(input_dir, img_size):
    """Build an ``ImageFolder`` dataset and its index-to-class-name mapping.

    Every image is passed through ``transforms.Resize(img_size)``, converted
    to a tensor, and normalized per channel with mean 0.5 and std 0.5.

    Args:
        input_dir: Root directory handed to ``ImageFolder``.
        img_size: Size argument forwarded to ``transforms.Resize``.

    Returns:
        A ``(labels_dict, dataset)`` tuple, where ``labels_dict`` maps each
        class index to the class name at that position in
        ``dataset.classes``.
    """
    channel_mean = (0.5, 0.5, 0.5)
    channel_std = (0.5, 0.5, 0.5)

    # Image preprocessing pipeline.
    steps = [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    dataset = ImageFolder(root=input_dir, transform=transforms.Compose(steps))

    # Index -> class name lookup, in the order the dataset lists its classes.
    labels_dict = dict(enumerate(dataset.classes))
    return labels_dict, dataset
|
||
|
||
|
||
def get_loaders(
    dataset: Dataset,
    batch_size: int = 32,
    train_fraction: float = 0.8,
) -> Tuple[DataLoader, DataLoader]:
    """Randomly split ``dataset`` and wrap both parts in data loaders.

    Args:
        dataset: Sized dataset to split.
        batch_size: Batch size used by both loaders (default 32).
        train_fraction: Share of the samples assigned to the training split
            (default 0.8). The training size is rounded down; the validation
            split receives the remainder.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles its
        samples, the validation loader does not.

    Raises:
        ValueError: If ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(
            f"train_fraction must be within [0, 1], got {train_fraction!r}"
        )

    # Split the data into training and validation parts.
    total = len(dataset)  # type: ignore[arg-type]
    train_size = int(train_fraction * float(total))
    val_size = total - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Data loaders.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader
|
||
|
||
|
||
def load_model(model_path: str, labels_dict: dict, device: str = "cuda") -> nn.Module:
    """Load a saved model from ``model_path`` or create a new ResNet-50.

    If ``model_path`` is not an existing file, a ResNet-50 with the default
    pretrained weights is created and its final fully connected layer is
    replaced by one with ``len(labels_dict)`` outputs. Otherwise the pickled
    model is loaded and switched to evaluation mode.

    Args:
        model_path: Path of a model file previously written by ``torch.save``.
        labels_dict: Class mapping; only its length is used, to size the new
            classification layer.
        device: Device the model is placed on. When a CUDA device is
            requested but CUDA is unavailable, the CPU is used instead.

    Returns:
        The model, located on the resolved device.
    """
    target = torch.device(device)
    if target.type == "cuda" and not torch.cuda.is_available():
        # The default is "cuda"; without this fallback the function cannot
        # load a checkpoint on a CPU-only machine.
        print("CUDA is not available, falling back to CPU")
        target = torch.device("cpu")

    if not os.path.isfile(model_path):
        print("Start new model")
        model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT)
        model.fc = torch.nn.Linear(model.fc.in_features, len(labels_dict))
        # Place the fresh model on the same device a loaded one would use.
        return model.to(target)

    # NOTE: weights_only=False unpickles arbitrary Python objects, so only
    # load model files that come from a trusted source.
    model = torch.load(model_path, map_location=target, weights_only=False)
    model.eval()
    return model
|
||
|
||
|
||
def _run_epoch(model, loader, criterion, optimizer=None) -> Tuple[float, float]:
    """Run one pass over ``loader`` and return its loss and accuracy.

    With an ``optimizer`` the pass updates the model after every batch;
    without one it only evaluates, with gradient tracking disabled.

    Returns:
        ``(loss, accuracy)``: the per-batch losses averaged over the number
        of batches, and the share of correct predictions in percent.
    """
    training = optimizer is not None
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            # The predicted class is the index of the largest output.
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return loss_sum / len(loader), 100.0 * correct / total


def train(
    num_epochs: int,
    model_namme: str,
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
) -> Tuple[list[float], list[float], list[float], list[float]]:
    """Train ``model`` and save it whenever the validation loss improves.

    Each epoch runs one training pass over ``train_loader`` followed by one
    evaluation pass over ``val_loader``. Whenever the validation loss drops
    below the best value seen so far, the whole model object is written to
    ``model_namme`` with ``torch.save``.

    Args:
        num_epochs: Number of epochs to run.
        model_namme: Destination path for the best model. The misspelled
            name is kept so existing keyword callers keep working.
        model: Network to train; it must expose an ``fc`` submodule, whose
            parameters are the only ones handed to the optimizer.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.

    Returns:
        Per-epoch histories, in this order: validation accuracy, training
        accuracy, validation loss, training loss. Accuracies are percentages.
    """
    # Batches are moved to DEVICE, so the model has to live there as well;
    # otherwise a CPU model fed CUDA inputs fails with a device mismatch.
    model.to(DEVICE)

    criterion = torch.nn.CrossEntropyLoss()
    # NOTE(review): only the fc parameters are optimized, but nothing here
    # freezes the remaining layers, so gradients are still computed for them
    # unless the caller disabled requires_grad beforehand.
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001, weight_decay=0.001)  # type: ignore[union-attr]

    # Metric histories.
    train_loss_history: list[float] = []
    train_acc_history: list[float] = []
    val_loss_history: list[float] = []
    val_acc_history: list[float] = []

    # Train with validation, keeping the best model on disk.
    best_val_loss = float("inf")
    for epoch in range(num_epochs):
        # Training pass.
        model.train()
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        train_loss_history.append(train_loss)
        train_acc_history.append(train_acc)
        print(
            f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%"
        )

        # Validation pass.
        model.eval()
        val_loss, val_acc = _run_epoch(model, val_loader, criterion)
        val_loss_history.append(val_loss)
        val_acc_history.append(val_acc)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print("save model")
            torch.save(model, model_namme)
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")
    return val_acc_history, train_acc_history, val_loss_history, train_loss_history
|
||
|
||
|
||
def show(
    num_epochs: int,
    val_acc_history: list[float],
    train_acc_history: list[float],
    val_loss_history: list[float],
    train_loss_history: list[float],
):
    """Display training curves: one accuracy figure, then one loss figure.

    Each figure plots the training series in blue and the validation series
    in red against epoch numbers ``1..num_epochs``.

    Args:
        num_epochs: Number of epochs used for the x axis.
        val_acc_history: Validation accuracy per epoch.
        train_acc_history: Training accuracy per epoch.
        val_loss_history: Validation loss per epoch.
        train_loss_history: Training loss per epoch.
    """
    epochs = range(1, num_epochs + 1)

    # One entry per figure:
    # (training series, training label, validation series, validation label,
    #  title, y-axis label)
    charts = (
        # Accuracy chart.
        (
            train_acc_history,
            "Точность на обучении",
            val_acc_history,
            "Точность на валидации",
            "Точность на этапах обучения и проверки",
            "Точность (%)",
        ),
        # Loss chart.
        (
            train_loss_history,
            "Потери на обучении",
            val_loss_history,
            "Потери на валидации",
            "Потери на этапах обучения и проверки",
            "Потери",
        ),
    )

    for train_series, train_label, val_series, val_label, title, y_label in charts:
        plt.figure(figsize=(10, 5))
        plt.plot(epochs, train_series, "bo-", label=train_label)
        plt.plot(epochs, val_series, "ro-", label=val_label)
        plt.title(title)
        plt.xlabel("Эпохи")
        plt.ylabel(y_label)
        plt.legend()
        plt.grid()
        plt.show()
|