1. Transfer Learning
2. ResNet50 (Keras)
3. Dataset
4. Tutorial Code
1. Transfer Learning
Transfer learning is a technique that reuses a model trained on one task for a different task. By leveraging the features the existing model has already learned, training on the new task can be done quickly and efficiently.
Advantages
- Shorter training time: training can start from the features the pre-trained model has already learned, so the time needed to train on the new task is reduced.
- Better data efficiency: because knowledge learned from the original data is reused, the cost of collecting and preprocessing data for the new task is reduced.
- Better performance: leveraging the pre-trained model's features can improve performance on the new task.
Limitations
- Data similarity: the more similar the new task's data is to the data the original model was trained on, the more effective transfer learning is.
- Choice of base model: the effectiveness of transfer learning also depends on which pre-trained model is used.
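The basic recipe is short: load a pre-trained backbone, freeze its weights, and train only a new classification head on top. The snippet below is a minimal sketch of this pattern (with a 5-class softmax head matching the flower dataset used later); section 4 builds the same thing step by step with the functional API.
# transfer-learning pattern in a few lines (illustrative sketch only)
import tensorflow as tf
base = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', pooling='avg')
base.trainable = False                                   # freeze the pre-trained backbone
head = tf.keras.layers.Dense(5, activation='softmax')    # new task-specific classifier
model = tf.keras.Sequential([base, head])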
2. ResNet50 (Keras)
ResNet
- ResNet is a neural network architecture that uses residual connections to address the difficulty of training very deep networks.
- A residual connection feeds the output of an earlier layer directly into a later layer (a skip connection), which improves both the training speed and the performance of deep networks.
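As an illustration, a minimal residual block in Keras could look like the following sketch (simplified; the actual ResNet50 bottleneck blocks use 1x1/3x3/1x1 convolutions with batch normalization):
# minimal residual block sketch (assumes the input already has `filters` channels)
from tensorflow.keras import layers

def residual_block(x, filters):
    shortcut = x                                                       # skip connection
    y = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    y = layers.Conv2D(filters, 3, padding="same")(y)
    y = layers.Add()([shortcut, y])                                    # add the input back to the output
    return layers.Activation("relu")(y)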
ResNet50 (Keras)
- Keras ResNet50 is the ResNet50 model provided by Keras (tf.keras.applications).
- ResNet50 is a 50-layer network built from residual blocks and can be used for tasks such as image classification, object detection, and object recognition.
3. Dataset (Kaggle)
- Flower Classification
https://www.kaggle.com/datasets/sauravagarwal/flower-classification?resource=download
- Flower types - daisy, dandelion, roses, sunflowers, tulips
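The tutorial code below assumes the downloaded archive is unpacked into one folder per class, for example (the C:/Users/USER/Desktop/resnet50 paths in the code are simply where the data was placed locally):
data_full/
    daisy/
    dandelion/
    roses/
    sunflowers/
    tulips/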
4. Tutorial Code
1) Libraries
# python libraries installation
!pip install split-folders matplotlib opencv-python scipy
# display, transform, read, split ...
import numpy as np
import cv2 as cv
import os
import splitfolders
import matplotlib.pyplot as plt
# tensorflow
import tensorflow.keras as keras
import tensorflow as tf
# image processing
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img
# model / neural network
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
2) Data Preprocessing
- Check a sample image from each class
# daisy
img_daisy = image.load_img("C:/Users/USER/Desktop/resnet50/data_full/daisy/100080576_f52e8ee070_n.jpg")
img_daisy
# dandelion
img_dandelion = image.load_img("C:/Users/USER/Desktop/resnet50/data_full/dandelion/10043234166_e6dd915111_n.jpg")
img_dandelion
# roses
img_roses = image.load_img("C:/Users/USER/Desktop/resnet50/data_full/roses/10090824183_d02c613f10_m.jpg")
img_roses
# sunflowers
img_sunflowers = image.load_img("C:/Users/USER/Desktop/resnet50/data_full/sunflowers/1008566138_6927679c8a.jpg")
img_sunflowers
# tulips
img_tulips = image.load_img("C:/Users/USER/Desktop/resnet50/data_full/tulips/100930342_92e8746431_n.jpg")
img_tulips
- Split the dataset
# split the data 80/10/10 into a new folder named data_split
splitfolders.ratio("C:/Users/USER/Desktop/resnet50/data_full",
                   output="C:/Users/USER/Desktop/resnet50/data_split",
                   seed=1337, ratio=(0.8, 0.1, 0.1), group_prefix=None, move=False)
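As an optional sanity check, the number of images in each split and class can be listed (a small sketch assuming the paths above; split-folders names the output subfolders train, val, and test):
# optional: count images per split and class after splitting
import os
base = "C:/Users/USER/Desktop/resnet50/data_split"
for split in ["train", "val", "test"]:
    split_dir = os.path.join(base, split)
    for cls in sorted(os.listdir(split_dir)):
        print(split, cls, len(os.listdir(os.path.join(split_dir, cls))))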
- Create Keras data generators
# ImageDataGenerator with the ResNet50 preprocessing applied to every image
datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
# Classes
# define classes name
class_names = ['daisy','dandelion','roses','sunflowers','tulips']
# training data generator
train_generator = datagen.flow_from_directory(
    directory="C:/Users/USER/Desktop/resnet50/data_split/train/",
    classes=class_names,
    target_size=(224, 224),
    batch_size=32,
    class_mode="sparse",  # integer labels, to match sparse_categorical_crossentropy below
)
# validation data generator
valid_generator = datagen.flow_from_directory(
    directory="C:/Users/USER/Desktop/resnet50/data_split/val/",
    classes=class_names,
    target_size=(224, 224),
    batch_size=32,
    class_mode="sparse",
)
# test data generator
test_generator = datagen.flow_from_directory(
    directory="C:/Users/USER/Desktop/resnet50/data_split/test/",
    classes=class_names,
    target_size=(224, 224),
    batch_size=32,
    class_mode="sparse",
)
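It can be worth verifying the class mapping and batch shapes before training (optional quick check):
# quick check of the generator output
print(train_generator.class_indices)            # {'daisy': 0, 'dandelion': 1, ...}
batch_images, batch_labels = next(train_generator)
print(batch_images.shape, batch_labels.shape)   # (32, 224, 224, 3) (32,)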
3) Build the model, ResNet50
# ResNet50 model
resnet_50 = ResNet50(include_top=False, weights='imagenet', input_shape=(224,224,3))
# freeze all pre-trained ResNet50 layers so only the new head is trained
for layer in resnet_50.layers:
    layer.trainable = False
# build the entire model
x = resnet_50.output
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(512, activation='relu')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(256, activation='relu')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(64, activation='relu')(x)
x = layers.Dropout(0.5)(x)
predictions = layers.Dense(5, activation='softmax')(x)
model = Model(inputs = resnet_50.input, outputs = predictions)
4) Training
- optimizer='adam', epochs=10
# define training function
def trainModel(model, epochs, optimizer):
    model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    # the batch size is already set on the generators, so it is not passed to fit()
    return model.fit(train_generator, validation_data=valid_generator, epochs=epochs)
# launch the training
model_history = trainModel(model = model, epochs = 10, optimizer = "Adam")
# Loss curve
loss_train_curve = model_history.history["loss"]
loss_val_curve = model_history.history["val_loss"]
plt.plot(loss_train_curve, label = "Train")
plt.plot(loss_val_curve, label = "Validation")
plt.legend(loc = 'upper right')
plt.title("Loss")
plt.show()
# Accuracy curve
acc_train_curve = model_history.history["accuracy"]
acc_val_curve = model_history.history["val_accuracy"]
plt.plot(acc_train_curve, label = "Train")
plt.plot(acc_val_curve, label = "Validation")
plt.legend(loc = 'lower right')
plt.title("Accuracy")
plt.show()
5) Evaluate the performance
test_loss, test_acc = model.evaluate(test_generator)
print("The test loss is: ", test_loss)
print("The best accuracy is: ", test_acc*100)
6) Predict on new data
img = tf.keras.preprocessing.image.load_img('C:/Users/USER/Desktop/resnet50/data_full/roses/123128873_546b8b7355_n.jpg',
target_size=(224, 224))
img_array = tf.keras.preprocessing.image.img_to_array(img)
img_array = preprocess_input(np.array([img_array]))  # same preprocessing as the generators
img
# generate predictions for samples
predictions = model.predict(img_array)
print(predictions)
# generate argmax for predictions
class_id = np.argmax(predictions, axis = 1)
print(class_id)
# transform classes number into classes name
print(class_names[class_id.item()])
7) Save model
# save in the SavedModel format (a directory)
model.save('C:/Users/USER/Desktop/resnet50/saved_model/my_model')
# or as a single HDF5 (.h5) file
model.save('C:/Users/USER/Desktop/resnet50/saved_model/my_model.h5')
model = tf.keras.models.load_model('C:/Users/USER/Desktop/resnet50/saved_model/my_model')
model.summary()
- Of the 24,809,605 total parameters, only 1,221,893 (about 5%) are trainable, because the ResNet50 backbone is frozen.
- Transfer learning is therefore a very efficient approach.
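These counts can be reproduced directly from the model built above (small sketch):
# count total vs. trainable parameters of the model
import numpy as np
from tensorflow.keras import backend as K
total = model.count_params()
trainable = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
print(f"total: {total}, trainable: {trainable} ({100 * trainable / total:.1f}%)")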