keras -> pytorch, sanic -> litestar
This commit is contained in:
parent 95fe63ac6b
commit 61669ea702

@@ -1,3 +1,178 @@
assets/*
*.jpg
beerds.json

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc
@@ -0,0 +1 @@
3.11
@@ -0,0 +1,12 @@
api:
    uv run granian --interface asgi server.main:app

runml:
    uv run ml/beerds.py

format:
    uv run ruff format app

lint:
    uv run mypy ./ --explicit-package-bases;
    ruff check --fix
@@ -0,0 +1,5 @@
You need to install the NVIDIA driver + CUDA:

```
sudo apt install nvidia-cuda-toolkit
```
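A quick way to confirm that the driver and toolkit are actually visible to the training code is to ask PyTorch directly. A minimal sketch, assuming the `torch` dependency from `pyproject.toml` is installed (the file name is illustrative, not part of the commit):

```
# check_gpu.py -- illustrative helper, not part of this commit
import torch

# True only when the NVIDIA driver and a CUDA-enabled torch build are both present
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    # Name of the first visible GPU
    print("Device:", torch.cuda.get_device_name(0))
```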
Binary file not shown.
Binary file not shown.
beerds.py
@@ -1,107 +0,0 @@
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import image_dataset_from_directory, split_dataset


img_size = (200, 200)


# data augmentation
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ]
)


input_dir = "assets/dog"

labels_dict = {}
for fname in os.listdir(input_dir):
    if fname in labels_dict:
        continue
    labels_dict[fname] = len(labels_dict)

model_name = "beerd_25_04_2023.keras"
train_dataset, val_ds = image_dataset_from_directory(
    input_dir,
    labels="inferred",
    label_mode="categorical",
    class_names=None,
    color_mode="rgb",
    batch_size=32,
    seed=12,
    image_size=img_size,
    shuffle=True,
    validation_split=0.1,
    subset="both",
    interpolation="bilinear",
    follow_links=False,
    crop_to_aspect_ratio=False
)

validation_dataset, test_dataset = split_dataset(val_ds, left_size=0.8)

inputs = keras.Input(shape=img_size + (3,))
x = data_augmentation(inputs)
x = layers.Rescaling(1./255)(x)
x = layers.Conv2D(filters=32, kernel_size=5, use_bias=False)(x)
for size in [32, 64, 128, 256, 512, 1024]:
    residual = x
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.SeparableConv2D(size, 3, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    x = layers.SeparableConv2D(size, 3, padding="same", use_bias=False)(x)
    x = layers.MaxPooling2D(3, strides=2, padding="same")(x)
    residual = layers.Conv2D(
        size, 1, strides=2, padding="same", use_bias=False)(residual)
    x = layers.add([x, residual])

x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(len(labels_dict), activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.compile(optimizer="rmsprop",
              loss="categorical_crossentropy", metrics=['accuracy'])
callbacks = [
    keras.callbacks.ModelCheckpoint(model_name,
                                    save_best_only=True)
]
history = model.fit(train_dataset,
                    epochs=200,
                    callbacks=callbacks,
                    validation_data=validation_dataset,)

epochs = range(1, len(history.history["loss"]) + 1)
loss = history.history["loss"]
val_loss = history.history["val_loss"]
acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]
plt.plot(epochs, acc, "bo", label="Точность на этапе обучения")
plt.plot(epochs, val_acc, "b", label="Точность на этапе проверки")
plt.title("Точность на этапах обучения и проверки")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Потери на этапе обучения")
plt.plot(epochs, val_loss, "b", label="Потери на этапе проверки")
plt.title("Потери на этапах обучения и проверки")
plt.legend()
plt.show()


test_model = keras.models.load_model(model_name)
test_loss, test_acc = test_model.evaluate(test_dataset)
print(f"Test accuracy: {test_acc:.3f}")
@@ -1,128 +0,0 @@
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import image_dataset_from_directory, split_dataset
import keras_tuner

img_size = (180, 180)

conv_base = keras.applications.resnet.ResNet152(
    weights="imagenet",
    include_top=False,
    input_shape=(180, 180, 3))
conv_base.trainable = False
# conv_base.trainable = True
# for layer in conv_base.layers[:-4]:
#     layer.trainable = False


# data augmentation
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.1),
        layers.RandomZoom(0.2),
    ])


input_dir = "assets/dog"

labels_dict = {}
for fname in os.listdir(input_dir):
    if fname in labels_dict:
        continue
    labels_dict[fname] = len(labels_dict)

model_name = "beerd_imagenet_02_05_2023.keras"
model_dir = "beerd_imagenet"
train_dataset, val_ds = image_dataset_from_directory(
    input_dir,
    labels="inferred",
    label_mode="categorical",
    class_names=None,
    color_mode="rgb",
    batch_size=32,
    seed=12,
    image_size=img_size,
    shuffle=True,
    validation_split=0.1,
    subset="both",
    interpolation="bilinear",
    follow_links=False,
    crop_to_aspect_ratio=False
)
validation_dataset, test_dataset = split_dataset(val_ds, left_size=0.8)


def build_model(hp):
    inputs = keras.Input(shape=(180, 180, 3))
    x = data_augmentation(inputs)
    x = keras.applications.resnet.preprocess_input(x)
    x = conv_base(x)
    x = layers.Flatten()(x)
    units = hp.Int(name="units", min_value=1536, max_value=2048, step=512)
    x = layers.Dense(units, activation="relu")(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(len(labels_dict), activation="softmax")(x)
    model = keras.Model(inputs, outputs)
    optimizer = hp.Choice(name="optimizer", values=["rmsprop", "adam"])
    model.compile(optimizer=optimizer,
                  loss="categorical_crossentropy", metrics=['accuracy'])
    return model


def build_model_new():
    inputs = keras.Input(shape=(180, 180, 3))
    x = data_augmentation(inputs)
    x = keras.applications.resnet.preprocess_input(x)
    x = conv_base(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(0.5)(x)
    outputs = layers.Dense(len(labels_dict), activation="softmax")(x)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy", metrics=['accuracy'])
    return model


# tuner = keras_tuner.BayesianOptimization(
#     build_model,
#     objective='val_accuracy',
#     max_trials=100,
#     executions_per_trial=2,
#     directory=model_dir,
#     overwrite=True,)

# callbacks = [
#     keras.callbacks.EarlyStopping(monitor="val_loss", patience=5)
# ]
callbacks = [
    keras.callbacks.ModelCheckpoint(model_name,
                                    save_best_only=True, monitor="val_accuracy"),
    keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5)
]

# tuner.search(train_dataset,
#              epochs=100,
#              callbacks=callbacks,
#              validation_data=validation_dataset,)

# best_models = tuner.get_best_models(1)
# best_models = keras.models.load_model(model_name)
best_models = build_model_new()
best_models.fit(train_dataset,
                epochs=30,
                callbacks=callbacks,
                validation_data=validation_dataset)


test_model = keras.models.load_model(model_name)
test_loss, test_acc = test_model.evaluate(test_dataset)
print(f"Test accuracy: {test_acc:.3f}")
@@ -1,44 +0,0 @@
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import json
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import load_img, img_to_array
from tensorflow.keras.utils import image_dataset_from_directory


# model_name = "beerd_25_04_2023.keras"
model_name = "beerd_imagenet_25_04_2023.keras"
img = load_img("photo_2023-04-25_10-02-25.jpg", color_mode="rgb")
img = tf.image.resize(img, (180, 180, ), "bilinear")
img_array = img_to_array(img)

test_model = keras.models.load_model(model_name)
test_loss = test_model.predict(np.expand_dims(img_array, 0))


list_labels = [fname for fname in os.listdir("assets/dog")]
list_labels.sort()
dict_names = {}
for i, label in enumerate(list_labels):
    dict_names[i] = label

with open("beerds.json", "w") as f:
    f.write(json.dumps(dict_names))

max_val = 0
max_num = 0
for i, val in enumerate(test_loss[0]):
    if val < max_val:
        continue
    max_val = val
    max_num = i

print("-----------------------")
print(list_labels)
print(test_loss)
print(max_num, max_val, dict_names[max_num])
@@ -0,0 +1,165 @@
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder  # type: ignore
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms  # type: ignore
import torchvision
from typing import Tuple

# Set up the compute device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
IMG_SIZE = (200, 200)
INPUT_DIR = "assets/dog"
NUM_EPOCHS = 90


def get_labels(input_dir, img_size):
    # Image transforms
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = ImageFolder(root=input_dir, transform=transform)

    # Build labels_dict mapping class indices to class names
    labels_dict = {idx: class_name for idx, class_name in enumerate(dataset.classes)}
    return labels_dict, dataset


def get_loaders(dataset: Dataset) -> Tuple[DataLoader, DataLoader]:
    # Split the data into training and validation sets
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    return train_loader, val_loader


def load_model(model_path: str, device: str = 'cuda') -> nn.Module:
    if not os.path.isfile(model_path):
        print("Start new model")
        model = torchvision.models.resnet50(pretrained=True)
        model.fc = torch.nn.Linear(model.fc.in_features, len(labels_dict))
        return model
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()
    return model


def train(num_epochs: int, model: nn.Module, train_loader: DataLoader, val_loader: DataLoader) -> Tuple[list[float], list[float], list[float], list[float]]:
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001, weight_decay=0.001)
    # Metric history
    train_loss_history = []
    train_acc_history = []
    val_loss_history = []
    val_acc_history = []
    # Training loop with validation and best-model checkpointing
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        # Train on the training data
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total
        train_loss_history.append(train_loss)
        train_acc_history.append(train_acc)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

        # Evaluate on the validation data
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100. * correct / total
        val_loss_history.append(val_loss)
        val_acc_history.append(val_acc)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print("save model")
            torch.save(model, "full_model.pth")
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.2f}%")
    return val_acc_history, train_acc_history, val_loss_history, train_loss_history


def show(num_epochs: int,
         val_acc_history: list[float],
         train_acc_history: list[float],
         val_loss_history: list[float],
         train_loss_history: list[float]):

    # Plot the training curves
    epochs = range(1, num_epochs + 1)

    # Accuracy plot
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_acc_history, "bo-", label="Точность на обучении")
    plt.plot(epochs, val_acc_history, "ro-", label="Точность на валидации")
    plt.title("Точность на этапах обучения и проверки")
    plt.xlabel("Эпохи")
    plt.ylabel("Точность (%)")
    plt.legend()
    plt.grid()
    plt.show()

    # Loss plot
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_loss_history, "bo-", label="Потери на обучении")
    plt.plot(epochs, val_loss_history, "ro-", label="Потери на валидации")
    plt.title("Потери на этапах обучения и проверки")
    plt.xlabel("Эпохи")
    plt.ylabel("Потери")
    plt.legend()
    plt.grid()
    plt.show()


if __name__ == "__main__":
    # Initialise the data and the model
    labels_dict: dict[int, str]
    dataset: ImageFolder
    labels_dict, dataset = get_labels(INPUT_DIR, IMG_SIZE)

    model: nn.Module = load_model("full_model.pth").to(DEVICE)

    # Prepare the data loaders
    train_loader: DataLoader
    val_loader: DataLoader
    train_loader, val_loader = get_loaders(dataset)

    # Train the model
    val_acc_history, train_acc_history, val_loss_history, train_loss_history = train(NUM_EPOCHS, model, train_loader, val_loader)

    # Visualise the results
    show(NUM_EPOCHS, val_acc_history, train_acc_history, val_loss_history, train_loss_history)
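The inference script in the next hunk reads a labels.json file, while the training script above only builds labels_dict in memory. A minimal sketch of how that file could be written after get_labels(); the helper name and its placement are assumptions, not part of this commit:

```
import json

# Hypothetical helper: persist the index -> breed-name mapping so the
# inference script can translate predicted indices back to names.
def save_labels(labels_dict: dict[int, str], path: str = "labels.json") -> None:
    with open(path, "w") as f:
        json.dump({str(idx): name for idx, name in labels_dict.items()}, f, ensure_ascii=False)
```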
@@ -0,0 +1,52 @@
import torch
from torchvision import transforms  # type: ignore
import torch.nn.functional as F
from PIL import Image
import json


# Build labels_dict mapping class indices to class names
with open("labels.json", "r") as f:
    data_labels = f.read()
    labels_dict = json.loads(data_labels)


def load_model(model_path, device='cuda'):
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()
    return model

# Initialisation
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_model('full_model.pth', device=device)

# Image transforms (adapt to your data)
def predict_image(image_path, model, device='cuda'):
    img_size = (200, 200)
    preprocess = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    image = Image.open(image_path).convert('RGB')
    input_tensor = preprocess(image)
    input_batch = input_tensor.unsqueeze(0).to(device)  # add a batch dimension

    with torch.no_grad():
        output = model(input_batch)

    probabilities = F.softmax(output[0], dim=0)

    _, predicted_idx = torch.max(probabilities, 0)
    return predicted_idx.item(), probabilities.cpu().numpy()


# Usage example
image_path = 'assets/test/photo_2023-04-25_10-02-25.jpg'
predicted_idx, probabilities = predict_image(image_path, model, device)

# Assuming labels_dict is a dict of the form {index: 'class_name'}
predicted_label = labels_dict[str(predicted_idx)]
print(f'Predicted class: {predicted_label} (prob: {probabilities[predicted_idx]:.2f})')
@@ -0,0 +1,18 @@
[project]
name = "ai"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "granian>=2.2.4",
    "jinja2>=3.1.6",
    "matplotlib>=3.10.1",
    "mypy>=1.15.0",
    "numpy==1.23.5",
    "pyqt5>=5.15.11",
    "ruff>=0.11.5",
    "starlite>=1.51.16",
    "torch>=2.6.0",
    "torchvision>=0.21.0",
]
server/main.py
@@ -1,11 +1,10 @@
+from pathlib import Path
 from PIL import Image
-from sanic import Sanic
+from starlite import Starlite, Controller, StaticFilesConfig, get, post, Body, MediaType, RequestEncodingType, Starlite, UploadFile, Template, TemplateConfig
-from sanic.response import json as json_answer, text
+from starlite.contrib.jinja import JinjaTemplateEngine
 import numpy as np
-from tensorflow import keras
-from tensorflow.keras.utils import img_to_array
 import io
 import os
 import json
@@ -13,11 +12,10 @@ import requests
 os.environ['CUDA_VISIBLE_DEVICES'] = '-1'


-app = Sanic("Ai")
-model_name = "../beerd_imagenet_02_05_2023.keras"
+model_name = "models/beerd_imagenet_02_05_2023.keras"
 test_model_imagenet = keras.models.load_model(model_name)

-model_name = "../beerd_25_04_2023.keras"
+model_name = "./models/beerd_25_04_2023.keras"
 test_model = keras.models.load_model(model_name)

 dict_names = {}
@@ -25,11 +23,9 @@ with open("beerds.json", "r") as f:
     dict_names = json.loads(f.read())
 for key in dict_names:
     dict_names[key] = dict_names[key].replace("_", " ")
-app.static("/", "index.html", name="main")
-app.static("/static/", "static/", name="static")

 VK_URL = "https://api.vk.com/method/"
-TOKEN = "vk1.a.2VJFQn9oTIqfpVNcgk7OvxXU8TZPomCH4biRvZEZp8-tQTi8IdKajlXCY5vJbLFjjPGrRWpsM8wbG1mek2pVpktwqi1MGAFJQfQafg68buH7YiE3GtClgWhZNuDUX5PwQuANLRVh6Ao-DcN0Z-72AmWmsIKhf9A4yuE8q3O6Asn_miGvO9gUY_JpctKEVtAYIEhbJtQK7hxW8qpud8J5Vg"
+TOKEN = ""
 headers = {"Authorization": f"Bearer {TOKEN}"}
 group_id = 220240483
 postfix = "?v=5.131"
@@ -53,54 +49,70 @@ def get_images():
 get_images()


-@app.post("/beeds")
-async def beeds(request):
-    body = request.files.get("f").body
-
-    img = Image.open(io.BytesIO(body))
-    img = img.convert('RGB')
-
-    img_net = img.resize((180, 180, ), Image.BILINEAR)
-    img_array = img_to_array(img_net)
-    test_loss_image_net = test_model_imagenet.predict(
-        np.expand_dims(img_array, 0))
-
-    img = img.resize((200, 200, ), Image.BILINEAR)
-    img_array = img_to_array(img)
-    test_loss = test_model.predict(np.expand_dims(img_array, 0))
-
-    result = {}
-    for i, val in enumerate(test_loss[0]):
-        if val <= 0.09:
-            continue
-        result[val] = dict_names[str(i)]
-
-    result_net = {}
-    for i, val in enumerate(test_loss_image_net[0]):
-        if val <= 0.09:
-            continue
-        result_net[val] = dict_names[str(i)]
-    items_one = dict(sorted(result.items(), reverse=True))
-    items_two = dict(sorted(result_net.items(), reverse=True))
-    images = []
-    for item in items_one:
-        name = items_one[item].replace("_", " ")
-        if name not in IMAGES:
-            continue
-        images.append({"name": name, "url": IMAGES[name]})
-    for item in items_two:
-        name = items_two[item].replace("_", " ")
-        if name not in IMAGES:
-            continue
-        images.append({"name": name, "url": IMAGES[name]})
-    return json_answer({
-        "results": items_one,
-        "results_net": items_two,
-        "images": images,
-    })
-
-
-sitemap_data = '''<?xml version="1.0" encoding="UTF-8"?>
+class BeerdsController(Controller):
+    path = "/breeds"
+
+    @post("/", media_type=MediaType.TEXT)
+    async def beeds(self, data: UploadFile = Body(media_type=RequestEncodingType.MULTI_PART)) -> dict:
+        body = await data.read()
+
+        img = Image.open(io.BytesIO(body))
+        img = img.convert('RGB')
+
+        img_net = img.resize((180, 180, ), Image.BILINEAR)
+        img_array = img_to_array(img_net)
+        test_loss_image_net = test_model_imagenet.predict(
+            np.expand_dims(img_array, 0))
+
+        img = img.resize((200, 200, ), Image.BILINEAR)
+        img_array = img_to_array(img)
+        test_loss = test_model.predict(np.expand_dims(img_array, 0))
+
+        result = {}
+        for i, val in enumerate(test_loss[0]):
+            if val <= 0.09:
+                continue
+            result[val] = dict_names[str(i)]
+
+        result_net = {}
+        for i, val in enumerate(test_loss_image_net[0]):
+            if val <= 0.09:
+                continue
+            result_net[val] = dict_names[str(i)]
+        items_one = dict(sorted(result.items(), reverse=True))
+        items_two = dict(sorted(result_net.items(), reverse=True))
+        images = []
+        for item in items_one:
+            name = items_one[item].replace("_", " ")
+            if name not in IMAGES:
+                continue
+            images.append({"name": name, "url": IMAGES[name]})
+        for item in items_two:
+            name = items_two[item].replace("_", " ")
+            if name not in IMAGES:
+                continue
+            images.append({"name": name, "url": IMAGES[name]})
+        return {
+            "results": items_one,
+            "results_net": items_two,
+            "images": images,
+        }
+
+
+class BaseController(Controller):
+    path = "/"
+
+    @get("/")
+    async def main(self) -> Template:
+        return Template(name="index.html")
+
+    @get("/sitemap.xml", media_type=MediaType.XML)
+    async def sitemaps(self) -> bytes:
+        return '''<?xml version="1.0" encoding="UTF-8"?>
 <urlset
     xmlns="http://www.sitemaps.org/schemas/sitemap/0.9"
     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
@@ -116,22 +128,27 @@ sitemap_data = '''<?xml version="1.0" encoding="UTF-8"?>
 </urlset>
-'''
+'''.encode()

-@app.get("/sitemap.xml")
-async def sitemaps(request):
-    return text(sitemap_data, content_type="application/xml")
-
-
-robots_data = '''
+    @get("/robots.txt", media_type=MediaType.TEXT)
+    async def robots(self) -> str:
+        return '''
 User-agent: *
 Allow: /

 Sitemap: https://xn-----6kcp3cadbabfh8a0a.xn--p1ai/sitemap.xml
 '''

-@app.get("/robots.txt")
-async def robots(request):
-    return text(robots_data)
-
-
-if __name__ == "__main__":
-    app.run(auto_reload=True, port=4003, host="0.0.0.0")
+app = Starlite(
+    route_handlers=[BeerdsController, BaseController],
+    static_files_config=[
+        StaticFilesConfig(directories=["static"], path="/static"),
+    ],
+    template_config=TemplateConfig(
+        directory=Path("templates"),
+        engine=JinjaTemplateEngine,
+    ),
+)
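For reference, a rough client-side sketch of exercising the new endpoint; the multipart field name "f" is taken from the upload form in index.html, while the host, port, and use of the requests package are assumptions rather than part of the commit:

```
import requests

# Post a test photo to the Starlite handler mounted at /breeds
with open("assets/test/photo_2023-04-25_10-02-25.jpg", "rb") as photo:
    resp = requests.post("http://127.0.0.1:8000/breeds", files={"f": photo})

# Expected shape: {"results": ..., "results_net": ..., "images": [...]}
print(resp.json())
```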
index.html
@@ -31,7 +31,7 @@
 <p>Загрузите фото, чтобы опеределить породу собаки или щенка. Если порода смешанная (или порода определена неточно), после загрузки будет показана вероятность породы животного.</p>
 <p>Определение породы происходит при помощи нейронной сети - точность опеределения составляет 60%, сеть обучена на <a href="https://vk.com/albums-220240483" target="_blank">125 породах</a>. Если на фото будет неизвестная порода или не собака - сеть не сможет правильно опеределить, что это.</p>
 <p>Для распознования все фото отправляются на сервер, но там не сохраняются</p>
-<form enctype="multipart/form-data" method="post" action="/beeds" onsubmit="SavePhoto();return false">
+<form enctype="multipart/form-data" method="post" action="/breeds" onsubmit="SavePhoto();return false">
 <p><input type="file" name="f" id="file-input">
 <input type="submit" value="Определить"></p>
 </form>
@@ -1,12 +1,10 @@
 import os
 import time
-import requests
+import requests  # type: ignore

 # Получить токен чтобы:
 # https://oauth.vk.com/oauth/authorize?client_id=51534014&display=page&scope=photos,offline&response_type=token&v=5.131&slogin_h=4984535b54c59e09ca.f1e0b6dce0d0cc82e7&__q_hash=618f24fbac4bc34edbf09b8bc503e923
-#TOKEN = "vk1.a.mf4KFdN9gC14SSGDFHVwFRTpzBKBeNxkdlEe0IFlZqU5a5rHH5PwiPn5ekWnDhc94lEI5d2vtXzfxvjXRPapsQZCCt89YUwCIQB1alo06A0Iup9PCWbd6F5GayBn0TS_26N5BTQ1B7deFzi25BV3LKimP9g5ZkeoY0xhNfQ7XawPnBhhK0a2ipL5zZxygYgf"
-TOKEN = "vk1.a.2VJFQn9oTIqfpVNcgk7OvxXU8TZPomCH4biRvZEZp8-tQTi8IdKajlXCY5vJbLFjjPGrRWpsM8wbG1mek2pVpktwqi1MGAFJQfQafg68buH7YiE3GtClgWhZNuDUX5PwQuANLRVh6Ao-DcN0Z-72AmWmsIKhf9A4yuE8q3O6Asn_miGvO9gUY_JpctKEVtAYIEhbJtQK7hxW8qpud8J5Vg"
+TOKEN = ""
 VK_URL = "https://api.vk.com/method/"
 headers = {"Authorization": f"Bearer {TOKEN}"}
 postfix = "?v=5.131&state=123456"