# beerds/ml/dogs.py
#
# (file-viewer metadata at time of capture: 46 lines, 1.3 KiB, Python)

import json
from PIL import ImageFile
import torch.nn as nn
from torchvision.datasets import ImageFolder # type: ignore
from torch.utils.data import DataLoader
from train import get_labels, load_model, get_loaders, train, show, DEVICE # type: ignore
# Run configuration for the dog classifier.
INPUT_DIR: str = "assets/dog"  # dataset root passed to get_labels()
IMG_SIZE: tuple[int, int] = (180, 180)  # image size passed to get_labels()
NUM_EPOCHS: int = 10  # epoch count passed to train() and show()
MODEL_NAME: str = "dogs_model.pth"  # model file name passed to load_model()/train()

# Ask Pillow to load truncated image files instead of erroring on them,
# so a partially written image does not abort the run.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Report the device chosen by train.DEVICE.
print(f"Using device: {DEVICE}")
def main() -> None:
    """Train the dog classifier and visualise its training history.

    Steps:
    1. Build the labelled dataset from ``INPUT_DIR`` (via ``get_labels``)
       and write its label mapping to ``labels_dogs.json`` in the current
       working directory.
    2. Obtain the model via ``load_model(MODEL_NAME, labels_dict)`` and
       move it to ``DEVICE``.
    3. Build training and validation DataLoaders via ``get_loaders``.
    4. Train for ``NUM_EPOCHS`` epochs, then pass the accuracy/loss
       histories to ``show`` for visualisation.
    """
    # Initialise data and model.
    labels_dict: dict[int, str]
    dataset: ImageFolder
    labels_dict, dataset = get_labels(INPUT_DIR, IMG_SIZE)

    # Persist the label mapping. NOTE: JSON object keys are always strings,
    # so integer keys are written as "0", "1", ...; whoever reads this file
    # must convert them back to int if needed.
    with open("labels_dogs.json", "w", encoding="utf-8") as f:
        json.dump(labels_dict, f)

    model: nn.Module = load_model(MODEL_NAME, labels_dict).to(DEVICE)

    # Prepare data loaders.
    train_loader: DataLoader
    val_loader: DataLoader
    train_loader, val_loader = get_loaders(dataset)

    # Train the model; train() returns the four per-run histories below.
    val_acc_history, train_acc_history, val_loss_history, train_loss_history = train(
        NUM_EPOCHS, MODEL_NAME, model, train_loader, val_loader
    )

    # Visualise the results.
    show(
        NUM_EPOCHS,
        val_acc_history,
        train_acc_history,
        val_loss_history,
        train_loss_history,
    )


if __name__ == "__main__":
    main()