import os import matplotlib.pyplot as plt import torch import torch.nn as nn from torchvision.datasets import ImageFolder # type: ignore from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms # type: ignore import torchvision from torchvision.models import ResNet50_Weights # type: ignore from typing import Tuple DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") def get_labels(input_dir, img_size): # Преобразования изображений transform = transforms.Compose( [ transforms.Resize(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ] ) dataset = ImageFolder(root=input_dir, transform=transform) # Создание labels_dict для соответствия классов и индексов labels_dict = {idx: class_name for idx, class_name in enumerate(dataset.classes)} return labels_dict, dataset def get_loaders(dataset: Dataset) -> Tuple[DataLoader, DataLoader]: # Разделение данных на тренировочные и валидационные train_size = int(0.8 * float(len(dataset))) # type: ignore[arg-type] val_size = len(dataset) - train_size # type: ignore[arg-type] train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) # Загрузчики данных train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False) return train_loader, val_loader def load_model(model_path: str, labels_dict: dict, device: str = "cuda") -> nn.Module: if not os.path.isfile(model_path): print("Start new model") model = torchvision.models.resnet50(weights=ResNet50_Weights.DEFAULT) model.fc = nn.Sequential( nn.Dropout(0.5), # Регуляризация torch.nn.Linear(model.fc.in_features, len(labels_dict)) ) return model model = torch.load(model_path, map_location=device, weights_only=False) model.eval() return model def train( num_epochs: int, model_namme: str, model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, ) -> Tuple[list[float], list[float], list[float], list[float]]: criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-4) # type: ignore[union-attr] # История метрик train_loss_history = [] train_acc_history = [] val_loss_history = [] val_acc_history = [] # Обучение с проверкой и сохранением лучшей модели best_val_loss = float("inf") for epoch in range(num_epochs): model.train() running_loss = 0.0 correct = 0 total = 0 # Обучение на тренировочных данных for inputs, labels in train_loader: inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() train_loss = running_loss / len(train_loader) train_acc = 100.0 * correct / total train_loss_history.append(train_loss) train_acc_history.append(train_acc) print( f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%" ) # Оценка на валидационных данных model.eval() val_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for inputs, labels in val_loader: inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) outputs = model(inputs) loss = criterion(outputs, labels) val_loss += loss.item() _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() val_loss /= len(val_loader) val_acc = 100.0 * correct / total val_loss_history.append(val_loss) val_acc_history.append(val_acc) if val_loss < best_val_loss: best_val_loss = val_loss print("save model") torch.save(model, model_namme) print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%") return val_acc_history, train_acc_history, val_loss_history, train_loss_history def show( num_epochs: int, val_acc_history: list[float], train_acc_history: list[float], val_loss_history: list[float], train_loss_history: list[float], ): # Построение графиков epochs = range(1, num_epochs + 1) # График точности plt.figure(figsize=(10, 5)) plt.plot(epochs, train_acc_history, "bo-", label="Точность на обучении") plt.plot(epochs, val_acc_history, "ro-", label="Точность на валидации") plt.title("Точность на этапах обучения и проверки") plt.xlabel("Эпохи") plt.ylabel("Точность (%)") plt.legend() plt.grid() plt.show() # График потерь plt.figure(figsize=(10, 5)) plt.plot(epochs, train_loss_history, "bo-", label="Потери на обучении") plt.plot(epochs, val_loss_history, "ro-", label="Потери на валидации") plt.title("Потери на этапах обучения и проверки") plt.xlabel("Эпохи") plt.ylabel("Потери") plt.legend() plt.grid() plt.show()