"""Train the cat-image classifier and plot its learning curves.

When run as a script this:
  1. builds the dataset and class-index -> label mapping from ``INPUT_DIR``,
  2. saves that mapping as JSON to ``LABELS_FILE``,
  3. loads/creates the model ``MODEL_NAME`` and moves it to ``DEVICE``,
  4. trains it for ``NUM_EPOCHS`` epochs with the project-local ``train`` helpers,
  5. visualizes the accuracy/loss histories via ``show``.
"""

import json

from PIL import ImageFile
import torch.nn as nn
from torchvision.datasets import ImageFolder  # type: ignore
from torch.utils.data import DataLoader

from train import get_labels, load_model, get_loaders, train, show, DEVICE  # type: ignore

# Let PIL load images whose files are truncated instead of raising an error.
# Kept at module level (not inside main) so the setting also takes effect in
# any process that merely imports this module — presumably relevant if
# get_loaders uses DataLoader worker processes started with "spawn"; verify.
ImageFile.LOAD_TRUNCATED_IMAGES = True

print(f"Using device: {DEVICE}")

IMG_SIZE = (224, 224)  # image size passed to get_labels
INPUT_DIR = "assets/cat"  # dataset root directory passed to get_labels
NUM_EPOCHS = 50
MODEL_NAME = "cats_model.pth"  # model file name passed to load_model and train
LABELS_FILE = "labels_cats.json"  # output path for the class-index -> label mapping


def _save_labels(labels_dict: dict[int, str], path: str = LABELS_FILE) -> None:
    """Write the class-index -> label mapping to *path* as JSON.

    Note: JSON object keys are always strings, so integer indices are stored
    as "0", "1", ...; readers must convert them back with ``int``.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(labels_dict, f)


def main() -> None:
    """Build data and model, train the model, and plot the training curves."""
    # Data and model initialization
    labels_dict: dict[int, str]
    dataset: ImageFolder
    labels_dict, dataset = get_labels(INPUT_DIR, IMG_SIZE)
    _save_labels(labels_dict)

    model: nn.Module = load_model(MODEL_NAME, labels_dict).to(DEVICE)

    # Data preparation
    train_loader: DataLoader
    val_loader: DataLoader
    train_loader, val_loader = get_loaders(dataset)

    # Model training; train returns the four metric histories in this order.
    val_acc_history, train_acc_history, val_loss_history, train_loss_history = train(
        NUM_EPOCHS, MODEL_NAME, model, train_loader, val_loader
    )

    # Results visualization
    show(
        NUM_EPOCHS,
        val_acc_history,
        train_acc_history,
        val_loss_history,
        train_loss_history,
    )


if __name__ == "__main__":
    main()