前端AI之路: KerasJS初探

前端AI之路: KerasJS初探

推荐先下载项目,直接运行起来看看效果:
地址:
eeandrew/react-kerasjsgithub.com图标

简介

Keras是一款非常流行的深度学习模型开发框架,基于python,语法简洁,封装程度高,只需十几行代码就可以构建一个深度神经网络。

Keras.js是一个可以在浏览器中运行深度神经网络的JS框架,支持CPU,GPU计算。区别于Keras,Keras.js只能运行已经调试好的模型,无法进行模型训练。

KerasJS开发流程如下,首先使用Keras开发训练神经网络,将神经网络模型和数据导出为文件,KerasJS在浏览器端加载此文件,这样才能进行预测。

KerasJS开发流程

模型

借鉴这篇文章

阿里云云栖社区:Keras快速上手——打造个人的第一个“圣诞老人”图像分类模型zhuanlan.zhihu.com图标

,开发一个识别圣诞老人的神经网络。本文不涉及Keras的开发细节,感兴趣的同学可以去原文查看。这里直接给出python代码

def build_model():
    model = models.Sequential()
    model.add(layers.Conv2D(20,(5,5),activation='relu',input_shape=(128,128,3)))
    model.add(layers.MaxPooling2D(pool_size=(2,2),strides=(2,2)))
    model.add(layers.Conv2D(50,(5,5),activation='relu',padding='same'))
    model.add(layers.MaxPooling2D(pool_size=(2,2),strides=(2,2)))
    model.add(layers.Flatten())
    model.add(layers.Dense(500,activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))
    model.compile(optimizer=optimizers.RMSprop(lr=2e-5),
                  loss='binary_crossentropy',
                  metrics=['acc'])
    return model

数据

标注数据是AI模型的原料,数据搜集特别是图片搜集是前端可以介入的一个环节。笔者基于React,开发了一款Chrome图片批量下载插件GetThemAll,方便我们进行标记图片搜集。

GetThemAll

安装好插件后,去谷歌图片搜索“santa”, 使用插件标记不需要的图片,然后下载到本地的santa文件夹,通过谷歌图片可以搜集到400张圣诞老人的图片。

接着我们再下载一些非圣诞老人的图片,搜索“object”,同样的使用GetThemAll插件下载大约400张图片到本地的non_santa文件夹中。

除了训练数据集,我们还需要一个测试数据集用来衡量模型的泛化能力。在本地新建一个test文件夹,把刚刚准备好的训练集里面的最后100张圣诞老人图片移到test文件夹下的santa文件中,同样的,移动100张非圣诞老人图片到non_stanta文件中。这样,你可以得到如下的本地图片集:

本地图片集

有了标记数据,我们就可以进行模型训练啦。具体的训练过程请见pyton代码,这里直接给出训练的结果,蓝点表示训练数据集准确率,蓝线表示测试数据集准备率,模型有着明显的High Variance问题,不过这个bug留给深度学习的专家们解决吧,这里就假设这个模型可用。

模型准确率

迁移

上一步训练出的模型keras_santa.h5(h5是文件后缀,和HTML5没啥关系)不能直接给KerasJS使用,需要通过KerasJS提供的转换工具转换后,方可被KerasJS加载解析。

./encoder.py keras_santa.h5

转换后,得到了keras_santa.bin文件,20M左右,这个文件包含了神经网络结构和所有参数,可以被KerasJS加载。

KerasJS

通过上面的步骤,我们得到了一个训练完成的CNN神经网络以及全部参数,这个网络结构和参数全部保存在keras_santa.bin文件中。接下来,我们只需要在浏览器中复原上面的神经网络,然后就可以开始做预测啦。

使用webpack配合React,搭建一套简单的开发环境。做好了基础工作,就可以开始第一步开发,加载神经网络模型文件keras_santa.bin:

const model = new KerasJS.Model({
   filepath: 'http://localhost:3000/keras_santa.bin',
   gpu: false
})
//KerasJS提供模型加载进度接口,考虑到模型文件体积非常大,这个接口会经常用到
model.events.on('loadingProgress', (progress) => {
      this.setState({
        loadingtitle: '模型加载',
        progress: parseInt(progress)
      })
})

使用上面的模型做预测前,需要将数据转化成模型能够接受的数据格式。这个圣诞老人网络需要数据输入格式为(128,128,3),也即是图片需要为128x128分辨率,只能包含RGB三个分量。

输入数据格式

借助canvas,可以实现图片分辨率转换:

_updateImageSrc(imgid) {
    const ctx = this.refs.canvas.getContext('2d');
    const imgdom = document.createElement('img');
    imgdom.src = `http://localhost:3000/${imgid}.jpeg`
    this.setState({
      prediction:0
    })
    imgdom.onload = ()=>{
      ctx.drawImage(imgdom,0,0,128,128)
      const imagedata = ctx.getImageData(0,0,128,128)
      const processeddata = ImageDataUtils.preprocess(imagedata)
      setTimeout(()=>{
        this.doPrediction(processeddata)
      },100);
    }
  }

注意preprocess方法,通过canvas获取到的图片资源包含了rgba四个维度,prepross返回这4个维度中的前3个维度,也即rgb,同时将数据标准化:

export default class ImageDataUtils {
  static preprocess(imageData) {
    const {
      width,
      height,
      data
    } = imageData;
    const dataTensor = ndarray(new Float32Array(data),[width,height,4])
    const dataProcessedTensor = ndarray(new Float32Array(width*height*3),[width,height,3])
    //从[0,255]转化到[0,1]
    ops.divseq(dataTensor,255)
    //获取R数据
    ops.assign(dataProcessedTensor.pick(null,null,0),dataTensor.pick(null,null,0))
    //获取G数据
    ops.assign(dataProcessedTensor.pick(null,null,1),dataTensor.pick(null,null,1))
    //获取B数据
    ops.assign(dataProcessedTensor.pick(null,null,2),dataTensor.pick(null,null,2))
    const preprocessedData = dataProcessedTensor.data
    return preprocessedData
  }   
}

最后,使用上面返回的数据做预测

async doPrediction(imagedata) {
    if(!this.model) return;
    const inputname = this.model.inputLayerNames[0]
    const inputdata = {[inputname]: imagedata}
    const prediction = await this.model.predict(inputdata)
    this.setState({
      prediction: prediction.output[0]
    })
  }
预测Demo

思考

可以看到,KerasJS在预测过程中,整个页面无法响应用户操作。这是因为神经网络计算过程中占用了大量CPU资源,从而致使页面卡顿。下一篇文章中,我们将介绍如何使用WebGL,将计算过程转移到GPU,达到实现前端高性能计算的目的。

相关资源

  1. Image classification with Keras and deep learning,Adrain Rosebrock
  2. GetThemAll, eeandrew
  3. React KerasJS, eeandrew

编辑于 2018-01-29

文章被以下专栏收录

    关注前端前沿技术,探寻业界深邃思想。https://qianduan.group 欢迎微信/微博搜索『前端外刊评论』,关注我们。欢迎给本专栏投稿,原作译作不限,要求:质量高!如果愿意尝试从事前端技术相关的书籍的编写或翻译工作,请私信外刊君。