ぺ ん ぎ ん の 閃 き

閃き- blog

きらびやかに、美しく、痛烈に.

Keras/TensorFlow: CIFAR-10 向けに VGG-like なアーキテクチャを作った.

VGG

1. 動作環境

OS: Ubuntu 16.04

Package             Version  
------------------- -------
python              3.5.0
tensorboard         1.9.0    
tensorflow          1.9.0    
h5py                2.8.0    
Keras               2.2.2    
Keras-Applications  1.0.4    
Keras-Preprocessing 1.0.2  

2. プログラム

import os
from datetime import datetime

import numpy as np
import tensorflow as tf
import keras.backend.tensorflow_backend as KTF
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, CSVLogger
from keras.datasets import cifar10
from keras.layers import Conv2D, MaxPooling2D
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential
from keras.optimizers import SGD, Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils

# ---------------------------------------------------------------------------
# Hyper-parameters and run metadata
# ---------------------------------------------------------------------------
# NOTE(review): img_row / img_col are floats and are not referenced by the
# visible code. img_ch is the maximum pixel value (the divisor used to scale
# pixels into [0, 1]), not a channel count.
img_row, img_col = 32.0, 32.0
img_ch = 255.0

# Training configuration.
batch_size, nb_classes, nb_epoch = 128, 10, 200
nb_data = 32 * 32  # NOTE(review): not referenced by the visible code

# Output locations. Each run gets a timestamped model name so that artifacts
# from different runs do not overwrite each other.
log_dir = '../train_log/vgg-like_log'
dataset_dir = '../../CIFAR-10/datasets'  # NOTE(review): not referenced by the visible code
model_name = 'vgg-like__' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
model_cp_path, model_csv_path = (
    os.path.join(log_dir, model_name + suffix)
    for suffix in ('_checkpoint.h5', '_csv.csv')
)

# Load the CIFAR-10 dataset through keras.datasets.
argmax_ch  = 255.0  # NOTE(review): not referenced by the visible code; img_ch is the divisor used below
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

# Scale pixel values from the 0-255 range into 0.0-1.0.
X_train = X_train.astype('float32') / img_ch
X_test  = X_test.astype('float32')  / img_ch

# Convert integer class labels (0 .. nb_classes-1) to one-hot vectors.
y_train = np_utils.to_categorical(y_train, nb_classes)
y_test  = np_utils.to_categorical(y_test, nb_classes)

# Cache the preprocessed arrays as .npy files (np.save appends the extension)
# in the current working directory.
# NOTE(review): dataset_dir is defined in the config section but unused; these
# files may have been intended to go there.
np.save('X_train', X_train)
np.save('y_train', y_train)
np.save('X_test' , X_test)
np.save('y_test' , y_test)

# Data augmentation. featurewise_center / featurewise_std_normalization make
# the generator normalise each batch with the dataset mean/std computed by
# datagen.fit(X_train); data fed to the model outside the generator needs the
# same normalisation.
datagen = ImageDataGenerator(
        featurewise_center=True,
        featurewise_std_normalization=True,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        vertical_flip=False)
datagen.fit(X_train)



# Remember the current Keras session so it can be restored once training in
# the dedicated graph below has finished.
old_session = KTF.get_session()

# ModelCheckpoint, CSVLogger, TensorBoard and the final save calls all write
# into log_dir; create it up front instead of failing mid-run.
os.makedirs(log_dir, exist_ok=True)

# datagen.flow() normalises every training batch with the featurewise mean/std
# fitted on X_train. Apply the same transform to the held-out data; otherwise
# validation and the final evaluation see un-normalised inputs.
# standardize() modifies its argument in place, hence the copy.
X_val = datagen.standardize(X_test.copy())

# fit_generator needs an integer number of batches per epoch; round up so the
# last partial batch is included.
steps_per_epoch = (len(X_train) + batch_size - 1) // batch_size

try:
    with tf.Graph().as_default():
        session = tf.Session('')
        KTF.set_session(session)
        # The learning phase is intentionally NOT forced to 1. Leaving it to
        # Keras lets Dropout/BatchNormalization use training behaviour during
        # fit and inference behaviour during validation / evaluate; forcing 1
        # made evaluation run with dropout active.

        # Build the VGG-style model: three conv stages (64, 128, 256 filters)
        # of Conv-BN-ReLU units followed by max-pooling and dropout, then two
        # 1024-unit dense layers and a softmax over nb_classes.
        model = Sequential()
        with tf.name_scope('inference'):
            model.add(Conv2D(64, (3, 3), padding='same', input_shape=X_train.shape[1:]))
            model.add(BatchNormalization())
            model.add(Activation('relu'))
            model.add(Conv2D(64, (3, 3), padding='same'))
            model.add(BatchNormalization())
            model.add(Activation('relu'))
            model.add(MaxPooling2D(pool_size=(2, 2)))
            model.add(Dropout(0.25))

            model.add(Conv2D(128, (3, 3), padding='same'))
            model.add(BatchNormalization())
            model.add(Activation('relu'))
            model.add(Conv2D(128, (3, 3), padding='same'))
            model.add(BatchNormalization())
            model.add(Activation('relu'))
            model.add(MaxPooling2D(pool_size=(2, 2)))
            model.add(Dropout(0.25))

            model.add(Conv2D(256, (3, 3), padding='same'))
            model.add(BatchNormalization())
            model.add(Activation('relu'))
            model.add(Conv2D(256, (3, 3), padding='same'))
            model.add(BatchNormalization())
            model.add(Activation('relu'))
            model.add(Conv2D(256, (3, 3), padding='same'))
            model.add(BatchNormalization())
            model.add(Activation('relu'))
            model.add(Conv2D(256, (3, 3), padding='same'))
            model.add(BatchNormalization())
            model.add(Activation('relu'))
            model.add(MaxPooling2D(pool_size=(2, 2)))
            model.add(Dropout(0.25))

            model.add(Flatten())
            model.add(Dense(1024))
            model.add(Activation('relu'))
            model.add(Dropout(0.5))
            model.add(Dense(1024))
            model.add(Activation('relu'))
            model.add(Dropout(0.5))
            model.add(Dense(nb_classes))
            model.add(Activation('softmax'))
        model.summary()

        # Callbacks: keep the checkpoint with the lowest val_loss, write
        # TensorBoard histograms every epoch, and log per-epoch metrics to CSV.
        cp_cb  = ModelCheckpoint(model_cp_path, monitor='val_loss', save_best_only=True)
        tb_cb  = TensorBoard(log_dir=log_dir, histogram_freq=1)
        csv_cb = CSVLogger(model_csv_path)
        callbacks = [cp_cb, tb_cb, csv_cb]

        # Compile and train on augmented, normalised batches; validate on the
        # identically normalised test split.
        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
        history = model.fit_generator(
            datagen.flow(X_train, y_train, batch_size=batch_size),
            steps_per_epoch=steps_per_epoch,
            epochs=nb_epoch,
            verbose=1,
            callbacks=callbacks,
            validation_data=(X_val, y_test))

        # Final evaluation of the last-epoch weights (not the best checkpoint).
        # score is [loss, accuracy] because of metrics=['accuracy'] above.
        score = model.evaluate(X_val, y_test, verbose=0)
        print('val score    : ', score[0])
        print('val accuracy : ', score[1])

        # Save the whole model "INSTANCE" (architecture + weights + optimizer state).
        f1_name = model_name + '_instance'
        f1_path = os.path.join(log_dir, f1_name) + '.h5'
        model.save(f1_path)

        # Save the model "WEIGHTs" only.
        f2_name = model_name + '_weights'
        f2_path = os.path.join(log_dir, f2_name) + '.h5'
        model.save_weights(f2_path)

        # Save the model "ARCHITECTURE" as JSON.
        f3_name = model_name + '_architechture'
        f3_path = os.path.join(log_dir, f3_name) + '.json'
        json_string = model.to_json()
        with open(f3_path, 'w') as f:
            f.write(json_string)
finally:
    # End of session: restore the previous Keras session even if training failed.
    KTF.set_session(old_session)

3. アーキテクチャ

Keras の keras.models.Model クラスが持つメソッド summary() を使う.
プログラム中の model.summary() により、モデルの構造 (architecture) の要約情報が標準出力に表示される.

Modelクラス (functional API) - Keras Documentation

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 32, 32, 64)        1792      
_________________________________________________________________
batch_normalization_1 (Batch (None, 32, 32, 64)        256       
_________________________________________________________________
activation_1 (Activation)    (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 64)        36928     
_________________________________________________________________
batch_normalization_2 (Batch (None, 32, 32, 64)        256       
_________________________________________________________________
activation_2 (Activation)    (None, 32, 32, 64)        0         
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 16, 16, 64)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 128)       73856     
_________________________________________________________________
batch_normalization_3 (Batch (None, 16, 16, 128)       512       
_________________________________________________________________
activation_3 (Activation)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 16, 128)       147584    
_________________________________________________________________
batch_normalization_4 (Batch (None, 16, 16, 128)       512       
_________________________________________________________________
activation_4 (Activation)    (None, 16, 16, 128)       0         
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 8, 8, 128)         0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 8, 8, 128)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 8, 8, 256)         295168    
_________________________________________________________________
batch_normalization_5 (Batch (None, 8, 8, 256)         1024      
_________________________________________________________________
activation_5 (Activation)    (None, 8, 8, 256)         0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 8, 8, 256)         590080    
_________________________________________________________________
batch_normalization_6 (Batch (None, 8, 8, 256)         1024      
_________________________________________________________________
activation_6 (Activation)    (None, 8, 8, 256)         0         
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 8, 8, 256)         590080    
_________________________________________________________________
batch_normalization_7 (Batch (None, 8, 8, 256)         1024      
_________________________________________________________________
activation_7 (Activation)    (None, 8, 8, 256)         0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 8, 8, 256)         590080    
_________________________________________________________________
batch_normalization_8 (Batch (None, 8, 8, 256)         1024      
_________________________________________________________________
activation_8 (Activation)    (None, 8, 8, 256)         0         
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 4, 4, 256)         0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 4, 4, 256)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 4096)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1024)              4195328   
_________________________________________________________________
activation_9 (Activation)    (None, 1024)              0         
_________________________________________________________________
dropout_4 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 1024)              1049600   
_________________________________________________________________
activation_10 (Activation)   (None, 1024)              0         
_________________________________________________________________
dropout_5 (Dropout)          (None, 1024)              0         
_________________________________________________________________
dense_3 (Dense)              (None, 10)                10250     
_________________________________________________________________
activation_11 (Activation)   (None, 10)                0         
=================================================================
Total params: 7,586,378
Trainable params: 7,583,562
Non-trainable params: 2,816
_________________________________________________________________


4. 学習結果

validation accuracy は 90% 前後となった.

f:id:yumaloop:20180822170332p:plain

f:id:yumaloop:20180822170409p:plain