The Penguin's Flash of Insight

Flash of Insight - blog

Brilliantly, beautifully, scathingly.

TensorFlow: Solving MNIST with a Small CNN

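Just some TensorFlow practice.
Classification accuracy on the MNIST validation set: about 98.9%.

A simple implementation in Python 3.6.6 with TensorFlow 1.9.0. The code relies on tf.contrib (learn and slim), which was removed in TensorFlow 2.0, so it needs a 1.x release.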
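The accuracy figure is the fraction of validation images whose predicted class (the argmax of the softmax output, exposed as the "class" prediction in the program) matches the true label. The program hands this job to tf.contrib.metrics.streaming_accuracy, which keeps a running total of correct predictions and a count across evaluation batches. Here is a plain-Python sketch of that idea, not TensorFlow's implementation; the class name and the toy batches are hypothetical.

```python
from typing import Iterable, List, Sequence, Tuple


def argmax(probs: Sequence[float]) -> int:
    """Index of the largest probability (what tf.argmax(prob, 1) does for one row)."""
    return max(range(len(probs)), key=lambda i: probs[i])


class StreamingAccuracy:
    """Hypothetical running accuracy: accumulates correct predictions and a count over batches."""

    def __init__(self) -> None:
        self.correct = 0
        self.count = 0

    def update(self, predicted: Iterable[int], labels: Iterable[int]) -> None:
        for p, y in zip(predicted, labels):
            self.correct += int(p == y)
            self.count += 1

    def result(self) -> float:
        return self.correct / self.count if self.count else 0.0


# Two made-up mini-batches of softmax outputs (3 classes for brevity) with their labels
batches: List[Tuple[List[List[float]], List[int]]] = [
    ([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]], [0, 1]),
    ([[0.3, 0.3, 0.4], [0.6, 0.3, 0.1]], [2, 1]),
]

metric = StreamingAccuracy()
for probs_batch, labels in batches:
    metric.update([argmax(p) for p in probs_batch], labels)

print(metric.result())  # 3 of 4 correct -> 0.75
```

Accumulating totals, rather than averaging per-batch accuracies, keeps the result exact even when the last batch is smaller than the others.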


1. Program

# coding: utf-8

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import mnist
from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec

learn = tf.contrib.learn
slim  = tf.contrib.slim


# Model definition (Convolutional Neural Network)
def cnn(x, y):
    # Reshape flat 784-pixel vectors into 28x28 single-channel images
    x = tf.reshape(x, [-1, 28, 28, 1])
    # Convert integer labels (0-9) into one-hot vectors
    y = slim.one_hot_encoding(y, 10)

    # data → conv → pool → conv → pool → fully connected → fully connected → (softmax) → class
    net = slim.conv2d(x, 48, [5, 5], scope='conv1')
    net = slim.max_pool2d(net, [2, 2], scope='pool1')
    net = slim.conv2d(net, 96, [5, 5], scope='conv2')
    net = slim.max_pool2d(net, [2, 2], scope='pool2')
    net = slim.flatten(net, scope='flatten')
    net = slim.fully_connected(net, 512, scope='fully_connected1')
    logits = slim.fully_connected(net, 10, activation_fn=None, scope='fully_connected2')
    prob = slim.softmax(logits)
    # slim.losses is tf.contrib.losses, whose argument order is (logits, onehot_labels)
    loss = slim.losses.softmax_cross_entropy(logits, y)
    train_op = slim.optimize_loss(loss, slim.get_global_step(), learning_rate=0.001, optimizer='Adam')
    return {'class': tf.argmax(prob, 1), 'prob': prob}, loss, train_op


# Load the MNIST data (labels as integers, not one-hot)
data_sets = mnist.read_data_sets('/tmp/mnist', one_hot=False)

# Set up the arrays (the validation split serves as the evaluation set)
X_train = data_sets.train.images
Y_train = data_sets.train.labels
X_test = data_sets.validation.images
Y_test = data_sets.validation.labels

# Print training logs, including validation results, to the console
tf.logging.set_verbosity(tf.logging.INFO)
validation_metrics = {"accuracy": MetricSpec(metric_fn=tf.contrib.metrics.streaming_accuracy, prediction_key="class")}
validation_monitor = learn.monitors.ValidationMonitor(X_test, Y_test, metrics=validation_metrics, every_n_steps=100)

# Run training (a checkpoint is saved every 10 seconds)
classifier = learn.Estimator(model_fn=cnn, model_dir='/tmp/cnn_log', config=learn.RunConfig(save_checkpoints_secs=10))
classifier.fit(x=X_train, y=Y_train, steps=3200, batch_size=64, monitors=[validation_monitor])
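To see what the layers in cnn() do to the tensor shapes, here is a small plain-Python sketch of the arithmetic. It assumes TF-Slim's defaults (conv2d: stride 1, padding='SAME', with bias; max_pool2d: stride 2, padding='VALID'; fully_connected: with bias), and the helper names are made up for illustration. Under those assumptions the 28x28 input shrinks to 7x7x96 = 4,704 features before the fully connected layers.

```python
# Sketch of the shape/parameter arithmetic for cnn(), assuming TF-Slim defaults:
# conv2d -> stride 1, padding='SAME', with bias; max_pool2d -> stride 2, padding='VALID';
# fully_connected -> with bias. Helper names are hypothetical.
from typing import List, Tuple


def conv_same(h: int, w: int, c_in: int, c_out: int, k: int) -> Tuple[Tuple[int, ...], int]:
    """Stride-1 convolution with SAME padding keeps H and W; returns (shape, params)."""
    params = k * k * c_in * c_out + c_out  # kernel weights + one bias per output channel
    return (h, w, c_out), params


def pool_valid(h: int, w: int, c: int, k: int = 2, stride: int = 2) -> Tuple[Tuple[int, ...], int]:
    """Max pooling with VALID padding: out = (in - k) // stride + 1; no trainable params."""
    return ((h - k) // stride + 1, (w - k) // stride + 1, c), 0


def dense(n_in: int, n_out: int) -> Tuple[Tuple[int, ...], int]:
    """Fully connected layer: weight matrix plus bias vector."""
    return (n_out,), n_in * n_out + n_out


def summarize() -> List[Tuple[str, Tuple[int, ...], int]]:
    rows = []
    shape, p = conv_same(28, 28, 1, 48, 5)
    rows.append(("conv1", shape, p))
    shape, p = pool_valid(*shape)
    rows.append(("pool1", shape, p))
    shape, p = conv_same(*shape, 96, 5)
    rows.append(("conv2", shape, p))
    shape, p = pool_valid(*shape)
    rows.append(("pool2", shape, p))
    flat = shape[0] * shape[1] * shape[2]
    rows.append(("flatten", (flat,), 0))
    shape, p = dense(flat, 512)
    rows.append(("fully_connected1", shape, p))
    shape, p = dense(shape[0], 10)
    rows.append(("fully_connected2", shape, p))
    return rows


rows = summarize()
for name, shape, params in rows:
    print(f"{name:18s} {str(shape):14s} {params:>9,d}")

total = sum(params for _, _, params in rows)
print(f"total parameters: {total:,d}")  # total parameters: 2,530,634
```

Under these assumptions fully_connected1 alone holds about 95% of the parameters, which is typical when a wide dense layer follows flattened feature maps.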


2. Log (console output)

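The log looks like this (an excerpt near the end of training). Validation runs on the most recently saved checkpoint, so the global_step it reports (e.g. 3250) trails the current training step (3280). The steps also go past 3200 because fit(steps=...) counts incrementally from any checkpoint already in /tmp/cnn_log, which suggests this run resumed from an earlier one.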

...

INFO:tensorflow:global_step/sec: 2.20219
INFO:tensorflow:Starting evaluation at 2018-07-18-07:07:56
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/cnn_log/model.ckpt-3250
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-07-18-07:08:06
INFO:tensorflow:Saving dict for global step 3250: accuracy = 0.9898, global_step = 3250, loss = 0.03569727
INFO:tensorflow:Validation (step 3280): accuracy = 0.9898, loss = 0.03569727, global_step = 3250
INFO:tensorflow:loss = 0.0060795164, step = 3280 (40.518 sec)
INFO:tensorflow:Saving checkpoints for 3281 into /tmp/cnn_log/model.ckpt.
INFO:tensorflow:Saving checkpoints for 3316 into /tmp/cnn_log/model.ckpt.
INFO:tensorflow:Saving checkpoints for 3346 into /tmp/cnn_log/model.ckpt.
INFO:tensorflow:Saving checkpoints for 3376 into /tmp/cnn_log/model.ckpt.
INFO:tensorflow:global_step/sec: 2.33117
INFO:tensorflow:Starting evaluation at 2018-07-18-07:08:39
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/cnn_log/model.ckpt-3376
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-07-18-07:08:51
INFO:tensorflow:Saving dict for global step 3376: accuracy = 0.9894, global_step = 3376, loss = 0.03725804
INFO:tensorflow:Validation (step 3380): accuracy = 0.9894, loss = 0.03725804, global_step = 3376
INFO:tensorflow:loss = 0.005827286, step = 3380 (45.278 sec)
INFO:tensorflow:Saving checkpoints for 3381 into /tmp/cnn_log/model.ckpt.
INFO:tensorflow:Saving checkpoints for 3415 into /tmp/cnn_log/model.ckpt.
INFO:tensorflow:Saving checkpoints for 3449 into /tmp/cnn_log/model.ckpt.
INFO:tensorflow:Saving checkpoints for 3479 into /tmp/cnn_log/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0017738211.
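To follow how validation accuracy evolves without scrolling through the console, the ValidationMonitor lines can be pulled out with a regular expression. This is a minimal sketch, assuming the log has been saved as text; the sample lines are copied from the excerpt above, and the pattern is written against this particular log format, not an official TensorFlow utility.

```python
import re
from typing import List, Tuple

# Matches lines such as:
# INFO:tensorflow:Validation (step 3280): accuracy = 0.9898, loss = 0.03569727, global_step = 3250
VALIDATION_RE = re.compile(
    r"Validation \(step (\d+)\): accuracy = ([\d.]+), loss = ([\d.]+), global_step = (\d+)"
)


def parse_validation(log_text: str) -> List[Tuple[int, int, float, float]]:
    """Return (training step, checkpoint global_step, accuracy, loss) per validation line."""
    results = []
    for m in VALIDATION_RE.finditer(log_text):
        step, acc, loss, ckpt_step = m.groups()
        results.append((int(step), int(ckpt_step), float(acc), float(loss)))
    return results


log_excerpt = """\
INFO:tensorflow:Validation (step 3280): accuracy = 0.9898, loss = 0.03569727, global_step = 3250
INFO:tensorflow:loss = 0.0060795164, step = 3280 (40.518 sec)
INFO:tensorflow:Validation (step 3380): accuracy = 0.9894, loss = 0.03725804, global_step = 3376
"""

for step, ckpt_step, acc, loss in parse_validation(log_excerpt):
    print(f"step {step} (checkpoint {ckpt_step}): accuracy {acc:.2%}, loss {loss:.4f}")
```

The plain training-loss line is skipped because it lacks the "Validation (step" prefix. For the two validation lines this reports 98.98% and 98.94%, consistent with the roughly 98.9% quoted at the top.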

References:
qiita.com

deepage.net

Stanford University CS231n: Convolutional Neural Networks for Visual Recognition