45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
import json
|
|
|
|
from PIL import ImageFile
|
|
# Ask Pillow to load truncated image files instead of raising an error on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
import torch.nn as nn
|
|
from torchvision.datasets import ImageFolder # type: ignore
|
|
from torch.utils.data import DataLoader
|
|
|
|
from train import get_labels, load_model, get_loaders, train, show, DEVICE
|
|
|
|
# --- Run configuration -------------------------------------------------------
# Size pair handed to get_labels() together with INPUT_DIR.
# NOTE(review): axis order (height/width) is decided by get_labels — confirm there.
IMG_SIZE: tuple[int, int] = (180, 180)
# Directory given to get_labels() as the data source.
INPUT_DIR: str = "assets/dog"
# Epoch count passed to both train() and show().
NUM_EPOCHS: int = 10
# Name passed to load_model() and train(); presumably the weights file — verify in train.py.
MODEL_NAME: str = "dogs_model.pth"

# Report the device exported by the train module.
print(f"Using device: {DEVICE}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the dataset, obtain the model, train it, and
    # display the collected accuracy/loss histories.

    # Initialise data and model.
    labels_dict: dict[int, str]
    dataset: ImageFolder
    labels_dict, dataset = get_labels(INPUT_DIR, IMG_SIZE)

    # Store the label mapping as JSON in the current working directory.
    with open("labels_dogs.json", "w") as labels_file:
        labels_file.write(json.dumps(labels_dict))

    model: nn.Module = load_model(MODEL_NAME, labels_dict)
    model = model.to(DEVICE)

    # Prepare the data loaders.
    train_loader: DataLoader
    val_loader: DataLoader
    train_loader, val_loader = get_loaders(dataset)

    # Train the model; train() hands back four history sequences in this order.
    histories = train(NUM_EPOCHS, MODEL_NAME, model, train_loader, val_loader)
    val_acc_history, train_acc_history, val_loss_history, train_loss_history = histories

    # Visualise the results.
    show(
        NUM_EPOCHS,
        val_acc_history,
        train_acc_history,
        val_loss_history,
        train_loss_history,
    )
|