使用Keras快速搭建深度学习模型测试最新fashion-mnist数据集

使用Keras快速搭建深度学习模型测试最新fashion-mnist数据集

天清天清

keras交流群:523412399

github地址:QuantumLiu/fashion-mnist-demo-by-Keras

fashion-mnist简介

近日,一个名为fashion-mnist的图像分类数据集火遍机器学习圈。截止9月1日,在短短7天内就获得了2k+的star。

zalandoresearch/fashion-mnist



FashionMNIST是一个替代MNIST手写数字集的图像数据集。 它是由Zalando(一家德国的时尚科技公司)旗下的研究部门提供。其涵盖了来自10种类别的共7万个不同商品的正面图片。FashionMNIST的大小、格式和训练集/测试集划分与原始的MNIST完全一致。60000/10000的训练测试数据划分,28x28的灰度图片。你可以直接用它来测试你的机器学习和深度学习算法性能,且不需要改动任何的代码。

众所周知,mnist是深度学习/图像分类领域的入门demo和最受欢迎的标准数据集,经常用来测试网络有效性。

然而mnist数据集也有自己的缺点,fashion-mnist正是为了克服这些缺点而生。

写给专业的机器学习研究者
我们是认真的。取代MNIST数据集的原因由如下几个:
MNIST太简单了。 很多深度学习算法在测试集上的准确率已经达到99.6%!不妨看看我们基于scikit-learn上对经典机器学习算法的评测 和这段代码: "Most pairs of MNIST digits can be distinguished pretty well by just one pixel"(翻译:大多数MNIST只需要一个像素就可以区分开!) MNIST被用烂了。 参考:"Ian Goodfellow wants people to move away from mnist"(翻译:Ian Goodfellow希望人们不要再用MNIST了。) MNIST数字识别的任务不代表现代机器学习。 参考:"François Cholle: Ideas on MNIST do not transfer to real CV" (翻译:在MNIST上看似有效的想法没法迁移到真正的机器视觉问题上。)

使用尝鲜

搭建模型

看到了新数据集当然忍不住要尝鲜了~~

为了做先吃螃蟹的人,我们需要快速搭建一个图片分类模型模型进行训练。

我的首选当然是真爱keras

Keras是一个高层神经网络API,Keras由纯Python编写而成并基TensorflowTheano以及CNTK后端。Keras 为支持快速实验而生,能够把你的idea迅速转换为结果,如果你有如下需求,请选择Keras:
简易和快速的原型设计(keras具有高度模块化,极简,和可扩充特性)
支持CNN和RNN,或二者的结合
无缝CPU和GPU切换

使用keras搭建一个类似于vgg16网络的CNN模型非常的简洁优雅:

def vgg_fm(input_shape):
    input_tensor=Input(shape=input_shape)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(input_tensor)
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    # Block 2
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)

    # Block 3
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    # Block 4
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    # Classification block
    x = Flatten(name='flatten')(x)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    x = Dropout(0.5)(x)
    x = Dense(10, activation='softmax', name='predictions')(x)
    return Model(inputs=[input_tensor],outputs=[x]

keras相当于是Tensorflow/theano这样的张量计算图框架的高层封装,这个函数返回一个Model对象,包含一个完整的网络计算图,同时拥有一些用于快速训练的方法。

Model对象需要调用compile方法来完损失成函数、优化器等的指定,之后使用fit方法训练。

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])
model.fit(x=train_x,y=train_y,validation_data=(test_x,test_y),batch_size=batch_size,epochs=100)

keras在训练时,传入的data和label是numpy数组的实例,不需要转换成其他数据类型,十分方便。而且做分类任务时,label支持传入整数类型的类别ID,不一定是one-hot向量,可以节约内存。

根据fashion-mnist的官方文档,我们可以很方便的将数据集读取为numpy数组。

import mnist_reader
X_train, y_train = mnist_reader.load_mnist('data/fashion', kind='train')
X_test, y_test = mnist_reader.load_mnist('data/fashion', kind='t10k')

完整代码

import sys
import numpy as np
from vgg_fm import vgg_fm
from generators import read_data,reshape
from manager import GPUManager
from keras.callbacks import ModelCheckpoint
from callback import TargetStopping
if __name__=='__main__':
    gm=GPUManager()
    kwargs=dict(zip(['mode','version','batch_size'],sys.argv[1:]))
    mode,version,batch_size=list(map(lambda kd:kwargs.get(kd[0],kd[1]),zip(['mode','version','batch_size'],['vgg','v1',256])))
    batch_size=int(batch_size)
    model_name=mode+'_'+version
    with gm.auto_choice():
        (train_x,train_y),(test_x,test_y)=read_data('train'),read_data('test')
        train_x,test_x,train_y,test_y=reshape(train_x,False),reshape(test_x,False),np.expand_dims(train_y,-1),np.expand_dims(test_y,-1)
        input_shape=train_x.shape[1:]
        model=vgg_fm(input_shape)
        model.summary()
        model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])
        model.fit(x=train_x,y=train_y,validation_data=(test_x,test_y),batch_size=batch_size,epochs=100,
                  callbacks=[TargetStopping(filepath=model_name+'.h5',monitor='val_acc',mode='max',target=0.94),
                             ModelCheckpoint(filepath=model_name+'.h5',save_best_only=True,monitor='val_acc')])

实验结果

最终测试集准确率大概是在93.5%左右,而原始mnist使用同样的模型打到97以上应该很轻松,使用说fashion-mnist着实有一定的挑战性~

「真诚赞赏,手留余香」
还没有人赞赏,快来当第一个赞赏的人吧!
文章被以下专栏收录
8 条评论
推荐阅读