
Keras/TensorFlow: Building a MobileNetV2-like Architecture for CIFAR-10

[Figure: structure of the MobileNetV2 inverted residual block]

0. Background

MobileNetV2 is a neural network model introduced in "MobileNetV2: Inverted Residuals and Linear Bottlenecks", a paper published in April 2018 by Mark Sandler and colleagues at Google. As the name suggests, it is well suited to fast forward passes on mobile devices, and it cuts the parameter count while preserving accuracy. It is positioned as an upgrade of the MobileNet architecture proposed in another Google paper, "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"; its distinguishing features include an "expand layer" that explicitly increases the number of output channels, and the "skip connections" familiar from the ResNet family.
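
The core building block is the "inverted residual". Schematically, one block looks like this (t is the expansion factor; the skip connection applies only when the stride is 1 and the input and output channel counts match; the full implementation appears in mobilenetv2.py below):

    y = Conv 1x1 (expand: ch -> ch*t) -> BN -> ReLU6
    y = DepthwiseConv 3x3 (stride s)  -> BN -> ReLU6
    y = Conv 1x1 (project, linear)    -> BN
    out = x + y   (skip connection, only if s == 1 and ch_in == ch_out)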



1. Environment

OS: Ubuntu 16.04

Package             Version  
------------------- -------
python              3.5.0
tensorboard         1.9.0    
tensorflow          1.9.0    
h5py                2.8.0    
Keras               2.2.2    
Keras-Applications  1.0.4    
Keras-Preprocessing 1.0.2  

2. Scripts

import_CIFAR-10.py

This script saves the CIFAR-10 dataset in NumPy array (ndarray) format. Each CIFAR-10 image consists of 32 (width) × 32 (height) × 3 (RGB) pixels. As preprocessing, the pixel values are standardized per RGB channel (mean 0, standard deviation 1).

Keeping a dedicated directory and script for the dataset makes version management easier.

import numpy as np
from keras.datasets import cifar10
from keras.utils import np_utils

nb_classes = 10     # number of CIFAR-10 classes
argmax_ch  = 255.0  # max pixel value (for optional [0, 1] scaling instead of standardization)

if __name__=='__main__':
    # load CIFAR-10 data
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()

    # cast to float32 before normalization
    X_train = X_train.astype('float32')  # (alternatively, scale with: X_train /= argmax_ch)
    X_test  = X_test.astype('float32')   # (alternatively, scale with: X_test  /= argmax_ch)

    def ch_wise_normalization(X_type, ch):
        mean_ch = X_type[:, :, :, ch].mean()
        std_ch = X_type[:, :, :, ch].std()
        X_type[:, :, :, ch] = (X_type[:, :, :, ch] - mean_ch) / std_ch
        return X_type[:, :, :, ch]

    # normalize data for each R-G-B(0, 1, 2) channel 
    X_train[:, :, :, 0] = ch_wise_normalization(X_train, 0)
    X_train[:, :, :, 1] = ch_wise_normalization(X_train, 1)
    X_train[:, :, :, 2] = ch_wise_normalization(X_train, 2)

    X_test[:, :, :, 0]  = ch_wise_normalization(X_test, 0)
    X_test[:, :, :, 1]  = ch_wise_normalization(X_test, 1)
    X_test[:, :, :, 2]  = ch_wise_normalization(X_test, 2)

    # convert class label (0-9) to one-hot encoding format
    y_train = np_utils.to_categorical(y_train, nb_classes)
    y_test  = np_utils.to_categorical(y_test, nb_classes)

    # save datasets as np.ndarray class format files
    np.save('X_train', X_train)
    np.save('y_train', y_train)
    np.save('X_test' , X_test)
    np.save('y_test' , y_test)
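
To check that the preprocessing worked, the saved arrays can be reloaded and the per-channel statistics inspected (a minimal sketch; the .npy file names follow from the np.save calls above):

import numpy as np

X_train = np.load('X_train.npy')
y_train = np.load('y_train.npy')

print(X_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
# after standardization, each channel should have mean ~0 and std ~1
for ch in range(3):
    print('ch {}: mean={:.4f}, std={:.4f}'.format(ch, X_train[:, :, :, ch].mean(), X_train[:, :, :, ch].std()))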

mobilenetv2.py

This file defines the MobileNetV2 architecture and returns an instance of it.
Keras already provides a predefined architecture as keras.applications.mobilenetv2.MobileNetV2, but CIFAR-10 and CIFAR-100 images are only 32 pixels on a side, and the model described in the original paper, which targets 224-pixel ImageNet images, does not train well on them. I therefore built the architecture myself, referring to an implementation on GitHub.

import os
import warnings
import numpy as np
from keras.layers import Input, Activation, Conv2D, Dense, Dropout, BatchNormalization, ReLU, DepthwiseConv2D, GlobalAveragePooling2D, GlobalMaxPooling2D, Add
from keras.models import Model
from keras import regularizers

# round the number of filters to the nearest multiple of `divisor`
def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
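# e.g. _make_divisible(32, 8) -> 32 and _make_divisible(20, 8) -> 24;
# the rounding never shrinks a value by more than 10%:
# _make_divisible(18, 8) would first give 16 (an 11% drop), so it is bumped up to 24.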


# inverted residual block: expand -> depthwise conv -> project (+ optional skip connection)
def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id):
    prefix = 'block_{}_'.format(block_id)

    in_channels = inputs._keras_shape[-1]
    pointwise_conv_filters = int(filters * alpha)
    pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
    x = inputs

    # Expand
    if block_id:
        x = Conv2D(expansion * in_channels, kernel_size=1, strides=1, padding='same', use_bias=False, activation=None, kernel_initializer="he_normal", kernel_regularizer=regularizers.l2(4e-5), name=prefix + 'expand')(x)
        x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'expand_BN')(x)
        x = ReLU(6., name=prefix + 'expand_relu')(x)
    else:
        prefix = 'expanded_conv_'

    # Depthwise
    x = DepthwiseConv2D(kernel_size=3, strides=stride, activation=None, use_bias=False, padding='same', kernel_initializer="he_normal", depthwise_regularizer=regularizers.l2(4e-5), name=prefix + 'depthwise')(x)
    x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'depthwise_BN')(x)
    x = ReLU(6., name=prefix + 'depthwise_relu')(x)

    # Project
    x = Conv2D(pointwise_filters, kernel_size=1, strides=1, padding='same', use_bias=False, activation=None, kernel_initializer="he_normal", kernel_regularizer=regularizers.l2(4e-5), name=prefix + 'project')(x)
    x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'project_BN')(x)


    if in_channels == pointwise_filters and stride == 1:
        return Add(name=prefix + 'add')([inputs, x])
    return x

# build MobileNetV2 models
def MobileNetV2(input_shape=(32, 32, 3),
                alpha=1.0,
                depth_multiplier=1,
                include_top=True,
                pooling=None,
                classes=10):

    # filter size (first block)
    first_block_filters = _make_divisible(32 * alpha, 8)
    # input shape  (first block)
    img_input = Input(shape=input_shape)

    # model architecture
    x = Conv2D(first_block_filters, kernel_size=3, strides=1, padding='same', use_bias=False, kernel_initializer="he_normal", kernel_regularizer=regularizers.l2(4e-5), name='Conv1')(img_input)
    #x = BatchNormalization(epsilon=1e-3, momentum=0.999, name='bn_Conv1')(x)
    #x = ReLU(6., name='Conv1_relu')(x)

    x = _inverted_res_block(x, filters=16,  alpha=alpha, stride=1, expansion=1, block_id=0 )

    x = _inverted_res_block(x, filters=24,  alpha=alpha, stride=1, expansion=6, block_id=1 )
    x = _inverted_res_block(x, filters=24,  alpha=alpha, stride=1, expansion=6, block_id=2 )

    x = _inverted_res_block(x, filters=32,  alpha=alpha, stride=2, expansion=6, block_id=3 )
    x = _inverted_res_block(x, filters=32,  alpha=alpha, stride=1, expansion=6, block_id=4 )
    x = _inverted_res_block(x, filters=32,  alpha=alpha, stride=1, expansion=6, block_id=5 )

    x = _inverted_res_block(x, filters=64,  alpha=alpha, stride=2, expansion=6, block_id=6 )
    x = _inverted_res_block(x, filters=64,  alpha=alpha, stride=1, expansion=6, block_id=7 )
    x = _inverted_res_block(x, filters=64,  alpha=alpha, stride=1, expansion=6, block_id=8 )
    x = _inverted_res_block(x, filters=64,  alpha=alpha, stride=1, expansion=6, block_id=9 )

    x = _inverted_res_block(x, filters=96,  alpha=alpha, stride=1, expansion=6, block_id=10)
    x = _inverted_res_block(x, filters=96,  alpha=alpha, stride=1, expansion=6, block_id=11)
    x = _inverted_res_block(x, filters=96,  alpha=alpha, stride=1, expansion=6, block_id=12)

    x = _inverted_res_block(x, filters=160, alpha=alpha, stride=2, expansion=6, block_id=13)
    x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=14)
    x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=15)
    x = Dropout(rate=0.25)(x)

    x = _inverted_res_block(x, filters=320, alpha=alpha, stride=1, expansion=6, block_id=16)
    x = Dropout(rate=0.25)(x)

    # filter size (last block)
    if alpha > 1.0:
        last_block_filters = _make_divisible(1280 * alpha, 8)
    else:
        last_block_filters = 1280


    x = Conv2D(last_block_filters, kernel_size=1, use_bias=False, kernel_initializer="he_normal", kernel_regularizer=regularizers.l2(4e-5), name='Conv_1')(x)
    x = BatchNormalization(epsilon=1e-3, momentum=0.999, name='Conv_1_bn')(x)
    x = ReLU(6., name='out_relu')(x)
    
    # top layer (with or without the FC classifier)
    if include_top:
        x = GlobalAveragePooling2D(name='global_average_pool')(x)
        x = Dense(classes, activation='softmax', use_bias=True, name='Logits')(x)
    else:
        if pooling == 'avg':
            x = GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = GlobalMaxPooling2D()(x)

    # create model of MobileNetV2 (for CIFAR-10)
    model = Model(inputs=img_input, outputs=x, name='mobilenetv2_cifar10')
    return model
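
As a quick sanity check, the model can be instantiated and its size inspected (a sketch; the import assumes this file is saved as mobilenetv2.py, as titled above):

from mobilenetv2 import MobileNetV2

model = MobileNetV2(input_shape=(32, 32, 3), alpha=1.0, include_top=True, classes=10)
model.summary()
print(model.count_params())  # just under 2.5M parameters (see the results section below)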

main.py

This is the main script that ties the other files together and runs the training. Executing this file triggers the whole pipeline.

import os
from datetime import datetime
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, TensorBoard, CSVLogger, LearningRateScheduler

from datafeeders.datafeeder import DataFeeder
from app.mobilenet_v2 import MobileNetV2  # the custom model defined above, not keras.applications'

# set meta params
batch_size = 128
nb_classes = 10
nb_epoch   = 300

# set meta info
log_dir         = '../train_log/mobilenet_v2-like5_log'
dataset_dir     = '../../CIFAR-10/datasets/dataset_norm'
model_name      = 'mobilenet_v2-like5__' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
model_arch_path = os.path.join(log_dir, (model_name + '_arch.png'))
model_cp_path   = os.path.join(log_dir, (model_name + '_checkpoint.h5'))
model_csv_path  = os.path.join(log_dir, (model_name + '_csv.csv'))


# load data
DF = DataFeeder(dataset_dir)
X_train, y_train, X_test, y_test = DF.X_train, DF.y_train, DF.X_test, DF.y_test

# data augmentation
datagen = ImageDataGenerator(
        width_shift_range=0.2, 
        height_shift_range=0.2,
        vertical_flip=False,
        horizontal_flip=True
        )
datagen.fit(X_train)


# build model
model = MobileNetV2(input_shape=X_train.shape[1:], include_top=True, alpha=1.0)
model.summary()
print('Model Name: ', model_name)

# save model architecture plot (.png)
from keras.utils import plot_model
plot_model(model, to_file=model_arch_path, show_shapes=True)

# learning-rate schedule: step decay over the 300 epochs
learning_rates  = [2e-2] * 5            # epochs   0-4
learning_rates += [1e-2] * (50 - 5)     # epochs   5-49
learning_rates += [8e-3] * (100 - 50)   # epochs  50-99
learning_rates += [4e-3] * (150 - 100)  # epochs 100-149
learning_rates += [2e-3] * (200 - 150)  # epochs 150-199
learning_rates += [1e-3] * (300 - 200)  # epochs 200-299

# set callbacks
callbacks = []
callbacks.append(TensorBoard(log_dir=log_dir, histogram_freq=1))
callbacks.append(ModelCheckpoint(model_cp_path, monitor='val_loss', save_best_only=True))
callbacks.append(LearningRateScheduler(lambda epoch: float(learning_rates[epoch])))
callbacks.append(CSVLogger(model_csv_path)) 

# compile & train the model
model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=2e-2, momentum=0.9, decay=0.0, nesterov=False), metrics=['accuracy'])
history = model.fit_generator(
              datagen.flow(X_train, y_train, batch_size=batch_size),
              steps_per_epoch=len(X_train) // batch_size,
              epochs=nb_epoch,
              verbose=1,
              callbacks=callbacks,
              validation_data=(X_test, y_test))
   
# validation
val_loss, val_acc = model.evaluate(X_test, y_test, verbose=0)
print('Model Name: ', model_name)
print('Test loss     : {:.5f}'.format(val_loss))
print('Test accuracy : {:.5f}'.format(val_acc))

# save model "INSTANCE"
f1_name = model_name + '_instance'
f1_path = os.path.join(log_dir, f1_name) + '.h5'
model.save(f1_path)

# save model "WEIGHTs"
f2_name = model_name + '_weights'
f2_path = os.path.join(log_dir, f2_name) + '.h5'
model.save_weights(f2_path)

# save model "ARCHITECHTURE"
f3_name = model_name + '_architechture'
f3_path = os.path.join(log_dir, f3_name) + '.json'
json_string = model.to_json()
with open(f3_path, 'w') as f:
    f.write(json_string)

# save plot figure
from utils.plot_log import save_plot_log
save_plot_log(log_dir, model_csv_path, index='acc')
save_plot_log(log_dir, model_csv_path, index='loss')
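
The DataFeeder class itself is not shown in this post; a minimal sketch consistent with how it is used above would simply load the .npy files written by import_CIFAR-10.py (the class name and attribute names follow that usage, the rest is an assumption):

import os
import numpy as np

class DataFeeder:
    """Loads the CIFAR-10 arrays saved by import_CIFAR-10.py."""
    def __init__(self, dataset_dir):
        self.X_train = np.load(os.path.join(dataset_dir, 'X_train.npy'))
        self.y_train = np.load(os.path.join(dataset_dir, 'y_train.npy'))
        self.X_test  = np.load(os.path.join(dataset_dir, 'X_test.npy'))
        self.y_test  = np.load(os.path.join(dataset_dir, 'y_test.npy'))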


3. Results

The final validation accuracy was around 91%. Not bad for a model with just under 2.5 million parameters.
The accuracy moves by about ±5% depending on the learning-rate schedule, which makes tuning tricky.
Note: the y-axis of the loss plot is labeled "crossentropy", but since this is a multi-class problem it should properly read "categorical crossentropy". Sorry about that.

[Figure: training/validation accuracy curves]

[Figure: training/validation loss curves]