beerds/ml/dogs_check.py

56 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from torchvision import transforms # type: ignore
import torch.nn.functional as F
from PIL import Image
import json
# Создание labels_dict для соответствия классов и индексов
with open("labels.json", "r") as f:
data_labels = f.read()
labels_dict = json.loads(data_labels)
def load_model(model_path, device="cuda"):
model = torch.load(model_path, map_location=device, weights_only=False)
model.eval()
return model
# Инициализация
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("full_model.pth", device=device)
# Преобразования для изображения (адаптируйте под ваш случай)
# Преобразования изображений
def predict_image(image_path, model, device="cuda"):
img_size = (180, 180)
preprocess = transforms.Compose(
[
transforms.Resize(img_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
image = Image.open(image_path).convert("RGB")
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0).to(device) # Добавляем dimension для батча
with torch.no_grad():
output = model(input_batch)
probabilities = F.softmax(output[0], dim=0)
_, predicted_idx = torch.max(probabilities, 0)
return predicted_idx.item(), probabilities.cpu().numpy()
# Пример использования
image_path = "assets/test/photo_2023-04-25_10-02-25.jpg"
predicted_idx, probabilities = predict_image(image_path, model, device)
# Предполагая, что labels_dict - словарь вида {индекс: 'название_класса'}
predicted_label = labels_dict[str(predicted_idx)]
print(f"Predicted class: {predicted_label} (prob: {probabilities[predicted_idx]:.2f})")