56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
import torch
|
||
from torchvision import transforms # type: ignore
|
||
import torch.nn.functional as F
|
||
from PIL import Image
|
||
import json
|
||
|
||
|
||
# Создание labels_dict для соответствия классов и индексов
|
||
with open("labels.json", "r") as f:
|
||
data_labels = f.read()
|
||
labels_dict = json.loads(data_labels)
|
||
|
||
|
||
def load_model(model_path, device="cuda"):
|
||
model = torch.load(model_path, map_location=device, weights_only=False)
|
||
model.eval()
|
||
return model
|
||
|
||
|
||
# Инициализация
|
||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||
model = load_model("full_model.pth", device=device)
|
||
|
||
|
||
# Преобразования для изображения (адаптируйте под ваш случай)
|
||
# Преобразования изображений
|
||
def predict_image(image_path, model, device="cuda"):
|
||
img_size = (180, 180)
|
||
preprocess = transforms.Compose(
|
||
[
|
||
transforms.Resize(img_size),
|
||
transforms.ToTensor(),
|
||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||
]
|
||
)
|
||
image = Image.open(image_path).convert("RGB")
|
||
input_tensor = preprocess(image)
|
||
input_batch = input_tensor.unsqueeze(0).to(device) # Добавляем dimension для батча
|
||
|
||
with torch.no_grad():
|
||
output = model(input_batch)
|
||
|
||
probabilities = F.softmax(output[0], dim=0)
|
||
|
||
_, predicted_idx = torch.max(probabilities, 0)
|
||
return predicted_idx.item(), probabilities.cpu().numpy()
|
||
|
||
|
||
# Пример использования
|
||
image_path = "assets/test/photo_2023-04-25_10-02-25.jpg"
|
||
predicted_idx, probabilities = predict_image(image_path, model, device)
|
||
|
||
# Предполагая, что labels_dict - словарь вида {индекс: 'название_класса'}
|
||
predicted_label = labels_dict[str(predicted_idx)]
|
||
print(f"Predicted class: {predicted_label} (prob: {probabilities[predicted_idx]:.2f})")
|