100 lines
2.8 KiB
Python
100 lines
2.8 KiB
Python
|
|
import os
|
|
import random
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from tensorflow import keras
|
|
from tensorflow.keras import layers
|
|
from tensorflow.keras.utils import image_dataset_from_directory, split_dataset
|
|
|
|
|
|
# Spatial size every image is resized to when the dataset is loaded.
img_size = (180, 180)

# VGG16 convolutional base with ImageNet weights and without its classifier
# head. It is frozen so that only the layers stacked on top of it are trained.
conv_base = keras.applications.vgg16.VGG16(
    include_top=False,
    weights="imagenet",
    input_shape=img_size + (3,),
)
conv_base.trainable = False
|
|
|
|
|
|
# Data augmentation: random horizontal flips, rotations and zooms applied to
# the input images.
augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.2),
]
data_augmentation = keras.Sequential(augmentation_layers)
|
|
|
|
|
|
|
|
# Root folder of the image set; each class lives in its own subdirectory.
input_dir = "assets/dog"

# Map every class name to an integer index.
# Only subdirectories are counted, because the dataset loader infers classes
# from subdirectories; a stray file in `input_dir` must not add an extra
# output unit to the classifier. The names are sorted so the indices do not
# depend on the arbitrary order in which os.listdir() returns entries.
class_dirs = sorted(
    entry
    for entry in os.listdir(input_dir)
    if os.path.isdir(os.path.join(input_dir, entry))
)
labels_dict = {name: index for index, name in enumerate(class_dirs)}
|
|
|
|
# File name under which the trained model is stored.
model_name = "beerd_imagenet_25_04_2023.keras"

# Read the images from `input_dir`. With subset="both" the loader returns a
# pair: the training part and the held-out part (validation_split=0.1).
train_dataset, val_ds = image_dataset_from_directory(
    input_dir,
    # How labels are derived and encoded.
    labels="inferred",
    label_mode="categorical",
    class_names=None,
    # How images are decoded and resized.
    color_mode="rgb",
    image_size=img_size,
    interpolation="bilinear",
    crop_to_aspect_ratio=False,
    # Batching, shuffling and the train/held-out split.
    batch_size=32,
    shuffle=True,
    seed=12,
    validation_split=0.1,
    subset="both",
    follow_links=False,
)

# Split the held-out part again: 80% is used for validation during training,
# the remainder is kept for the final test evaluation.
validation_dataset, test_dataset = split_dataset(val_ds, left_size=0.8)
|
|
|
|
# Classifier: augmentation -> VGG16 preprocessing -> frozen convolutional
# base -> flatten -> dense head. The stages are applied in the listed order.
inputs = keras.Input(shape=img_size + (3,))
x = inputs
for stage in (
    data_augmentation,
    keras.applications.vgg16.preprocess_input,
    conv_base,
    layers.Flatten(),
    layers.Dense(512),  # NOTE(review): no activation is passed here
    layers.Dropout(0.5),
):
    x = stage(x)

# One softmax output per entry of labels_dict.
outputs = layers.Dense(len(labels_dict), activation="softmax")(x)
model = keras.Model(inputs, outputs)
|
|
|
|
# Categorical cross-entropy goes with the softmax output of the model.
model.compile(
    loss="categorical_crossentropy",
    optimizer="rmsprop",
    metrics=["accuracy"],
)

# Checkpoint to `model_name`, keeping only the best model seen so far
# (as judged by the callback's default monitored quantity).
checkpoint = keras.callbacks.ModelCheckpoint(model_name, save_best_only=True)
callbacks = [checkpoint]

history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=100,
    callbacks=callbacks,
)
|
|
|
|
# Per-epoch curves recorded by model.fit().
loss = history.history["loss"]
val_loss = history.history["val_loss"]
acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]
epochs = range(1, len(loss) + 1)  # epoch numbers start at 1 on the x axis


def _plot_curves(train_values, val_values, train_label, val_label, title):
    """Plot a training curve (dots) and a validation curve (line) with a legend."""
    plt.plot(epochs, train_values, "bo", label=train_label)
    plt.plot(epochs, val_values, "b", label=val_label)
    plt.title(title)
    plt.legend()


# Accuracy curves.
_plot_curves(
    acc,
    val_acc,
    "Точность на этапе обучения",
    "Точность на этапе проверки",
    "Точность на этапах обучения и проверки",
)

# Loss curves go on a separate figure.
plt.figure()
_plot_curves(
    loss,
    val_loss,
    "Потери на этапе обучения",
    "Потери на этапе проверки",
    "Потери на этапах обучения и проверки",
)

plt.show()
|
|
|
|
|
|
# Reload the model saved under `model_name` and score it on the test split.
test_model = keras.models.load_model(model_name)
evaluation = test_model.evaluate(test_dataset)
test_loss, test_acc = evaluation
print(f"Test accuracy: {test_acc:.3f}")