본문 바로가기

데이터 다루기/Vision (Tensorflow)

[Vision] Transfer Learning + Fine Tuning

728x90
반응형

이번 포스팅에서는 사전 훈련된 모델을 가지고 와서, 새로운 데이터에 학습시키는 Transfer learning에 대해서 실습하도록 하겠습니다.

총 2가지 방법으로 진행 해보도록 하겠습니다.

(1) 특징 추출

새 샘플에서 의미 있는 형상을 추출하기 위해 이전 네트워크에서 학습한 표현을 사용합니다. 사전 훈련된 모델 위에 처음부터 교육할 새 분류기를 추가하기만 하면 이전에 데이터셋으로 학습한 특징 맵의 용도를 재사용할 수 있습니다.

전체 모델을 재훈련시킬 필요는 없습니다. 기본 컨볼루션 네트워크에는 그림 분류에 일반적으로 유용한 기능이 이미 포함되어 있습니다. 그러나 사전 훈련된 모델의 최종 분류 부분은 기존의 분류 작업에 따라 다르며 이후에 모델이 훈련된 클래스 집합에 따라 다릅니다.

(2) 미세 조정 (Fine tuning)

고정된 기본 모델의 일부 최상위 층을 고정 해제하고 새로 추가 된 분류기 층과 기본 모델의 마지막 층을 함께 훈련시킵니다. 이를 통해 기본 모델에서 고차원 특징 표현을 "미세 조정"하여 특정 작업에 보다 관련성이 있도록 할 수 있습니다.

1. 사용할 패키지 불러오기.

from __future__ import absolute_import, division, print_function, unicode_literals
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

import를 활용하여 사용할 패키지를 불어왔습니다.

2. 데이터 불러오기.

이번 실습에서 사용할 데이터는 개 vs 고양이 분류 사진입니다.

(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

실제로 데이터셋에 어떠한 이미지가 있는지 살펴보도록 하겠습니다.

get_label_name = metadata.features['label'].int2str

for image, label in raw_train.take(2):
  plt.figure()
  plt.imshow(image)
  plt.title(get_label_name(label))

귀여운 강아지 사진들이 보이네요!

저희는 이러한 사진들을 보고 dog 인지 아니면 cat 인지를 예측해야 합니다.

3. 데이터 전처리

데이터 분석 프로세스에 있어서 모델링 전에 반드시 진행해야 하는 절차가 있습니다.

바로 데이터 전처리입니다.

우선 이미지를 고정된 Size로 설정해주고, -1 ~ 1 사이의 값을 가지도록 Normalize 시키겠습니다.

IMG_SIZE = 160 # 모든 이미지는 160x160으로 크기가 조정됩니다

def format_example(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label

이 때 map 함수를 사용하여 모든 이미지에 대해서 같은 전처리를 해줄 수 있습니다.

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

그리고 데이터의 Batch를 정의하도록 하겠습니다.

BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000

train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)
for image_batch, label_batch in train_batches.take(1):
   pass

image_batch.shape

TensorShape([32, 160, 160, 3])

위에 보시면 데이터가 예상대로 잘 전처리 된 것을 보실 수 있습니다.

4. 사전 학습된 모델 (Transfer Learning)

저희는 가벼운 CNN 네트워크로 많이 사용되는 MobileNetV2를 사전학습 모델로 불러오도록 하겠습니다.

IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

# 사전 훈련된 모델 MobileNet V2에서 기본 모델을 생성합니다.
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
base_model.summary()

Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]     

Summary를 통해 MobileNetV2의 구조를 보실 수 있습니다.

많은 레이어가 존재하는 것을 보실 수 있습니다.

이 모든 레이어를 학습하려면 사실상 시간이 오래걸리기 때문에, 사전 학습된 모델을 사용하는 의미가 없습니다.

실제로 이미 학습된 모델으로, 이미지의 특징을 잘 잡아낼 수 있기 때문에, 또 학습을 해줄 필요가 없습니다.

base_model.trainable = False

trainable = False로 지정해주어서, 이들을 학습되지 않게 할 수 있습니다.

이제 저희가 수행하려고 하는 Cat vs Dog Task에 맞게 분류층을 제일 끝 단에 추가하여야 합니다.

feature_batch = base_model(image_batch)
print(feature_batch.shape)

(32, 5, 5, 1280)

이미지 배치를 base model에 넣으면 (32,5,5,1280) 차원의 array로 반환됩니다.

우선 앞의 (5,5) 차원을 없애도록 하겠습니다.

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)

(32, 1280)

GlobalAveragePooling2D 레이어로 (32,1280) 차원으로 바꾸었습니다.

이제 이진 분류 예측을 위한 layer를 추가해보겠습니다.

prediction_layer = keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

(32, 1)

Dense (Fully connected layer)로 분류 예측을 하도록 만들었습니다.

model = tf.keras.Sequential([
  base_model,
  global_average_layer,
  prediction_layer
])

이제 이 두 레이어를 base_model의 끝부분에 추가하였습니다.

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

모델을 학습하도록 하겠습니다.

Optimizer는 RMSprop을 사용하였고, Loss는 BinaryCrossentropy입니다.

history = model.fit(train_batches,
                    epochs=initial_epochs,
                    validation_data=validation_batches)

Epoch 1/10
582/582 [==============================] - 40s 62ms/step - loss: 0.2151 - accuracy: 0.9088 - val_loss: 0.0863 - val_accuracy: 0.9695
Epoch 2/10
582/582 [==============================] - 37s 61ms/step - loss: 0.0704 - accuracy: 0.9747 - val_loss: 0.0630 - val_accuracy: 0.9764
Epoch 3/10
582/582 [==============================] - 37s 62ms/step - loss: 0.0568 - accuracy: 0.9794 - val_loss: 0.0554 - val_accuracy: 0.9781
Epoch 4/10
582/582 [==============================] - 37s 62ms/step - loss: 0.0510 - accuracy: 0.9814 - val_loss: 0.0516 - val_accuracy: 0.9798
Epoch 5/10
582/582 [==============================] - 37s 62ms/step - loss: 0.0477 - accuracy: 0.9829 - val_loss: 0.0496 - val_accuracy: 0.9811
Epoch 6/10
582/582 [==============================] - 38s 63ms/step - loss: 0.0455 - accuracy: 0.9837 - val_loss: 0.0482 - val_accuracy: 0.9815
Epoch 7/10
582/582 [==============================] - 38s 62ms/step - loss: 0.0438 - accuracy: 0.9840 - val_loss: 0.0472 - val_accuracy: 0.9824
Epoch 8/10
582/582 [==============================] - 37s 61ms/step - loss: 0.0426 - accuracy: 0.9838 - val_loss: 0.0465 - val_accuracy: 0.9837
Epoch 9/10
582/582 [==============================] - 36s 59ms/step - loss: 0.0416 - accuracy: 0.9846 - val_loss: 0.0460 - val_accuracy: 0.9850
Epoch 10/10
582/582 [==============================] - 35s 58ms/step - loss: 0.0406 - accuracy: 0.9848 - val_loss: 0.0456 - val_accuracy: 0.9850

보시다시피 끝단만 훈련을 시키는데도 성능이 굉장히 잘나옵니다.

즉, 새롭게 훈련을 시키지 않더라도 기존에 미리 학습된 모델이 이미지의 특징을 잘 추출하고 있다는 것입니다.

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

5. 미세조정 (Fine tuning)

위 처럼 끝단만 추가해도 성능이 잘나오지만, 더 좋은 성능을 내기 위한 방법으로 Fine tuning이 있습니다.

base_model.trainable = True

우선 전체 모델의 trainable을 True로 바꿔줍시다.

print("Number of layers in the base model: ", len(base_model.layers))

Number of layers in the base model:  154

모델의 전체 레이어는 154개입니다.

저희는 앞 100단은 고정을 시키고, 뒤 54단의 고정만 풀고 학습시키도록 하겠습니다.

# 해당 층 이후부터 미세 조정
fine_tune_at = 100

# `fine_tune_at` 층 이전의 모든 층을 고정
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False

위의 코드로 100단 이후만 trainable이 True 남아있게 됩니다.

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])

학습된 모델을 훼손시키지 않기 위해서 좀 더 낮은 learning rate를 채택하였습니다.

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_batches,
                         epochs=total_epochs,
                         initial_epoch =  history.epoch[-1],
                         validation_data=validation_batches)

모델을 학습하겠습니다.

Epoch 10/20
582/582 [==============================] - 44s 66ms/step - loss: 0.2044 - accuracy: 0.9270 - val_loss: 0.0815 - val_accuracy: 0.9789
Epoch 11/20
582/582 [==============================] - 39s 65ms/step - loss: 0.0704 - accuracy: 0.9704 - val_loss: 0.0547 - val_accuracy: 0.9824
Epoch 12/20
582/582 [==============================] - 40s 67ms/step - loss: 0.0544 - accuracy: 0.9809 - val_loss: 0.0481 - val_accuracy: 0.9824
Epoch 13/20
582/582 [==============================] - 40s 66ms/step - loss: 0.0446 - accuracy: 0.9837 - val_loss: 0.0509 - val_accuracy: 0.9785
Epoch 14/20
582/582 [==============================] - 41s 68ms/step - loss: 0.0329 - accuracy: 0.9863 - val_loss: 0.0490 - val_accuracy: 0.9819
Epoch 15/20
582/582 [==============================] - 41s 68ms/step - loss: 0.0231 - accuracy: 0.9920 - val_loss: 0.0510 - val_accuracy: 0.9811
Epoch 16/20
582/582 [==============================] - 41s 68ms/step - loss: 0.0171 - accuracy: 0.9936 - val_loss: 0.0493 - val_accuracy: 0.9824
Epoch 17/20
582/582 [==============================] - 41s 69ms/step - loss: 0.0159 - accuracy: 0.9947 - val_loss: 0.0521 - val_accuracy: 0.9850
Epoch 18/20
582/582 [==============================] - 41s 68ms/step - loss: 0.0111 - accuracy: 0.9961 - val_loss: 0.0527 - val_accuracy: 0.9845
Epoch 19/20
582/582 [==============================] - 41s 69ms/step - loss: 0.0087 - accuracy: 0.9975 - val_loss: 0.0579 - val_accuracy: 0.9837
Epoch 20/20
582/582 [==============================] - 41s 69ms/step - loss: 0.0089 - accuracy: 0.9975 - val_loss: 0.0565 - val_accuracy: 0.9841

성능이 더 올라간 것을 확인할 수 있습니다.

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

반응형