我们给你推荐一种TensorFlow模型格式

我们给你推荐一种TensorFlow模型格式

简介

TensorFlow的模型格式有很多种,针对不同场景可以使用不同的格式,只要符合规范的模型都可以轻易部署到在线服务或移动设备上,这里简单列举一下。

  • Checkpoint: 用于保存模型的权重,主要用于模型训练过程中参数的备份和模型训练热启动。
  • GraphDef:用于保存模型的Graph,不包含模型权重,加上checkpoint后就有模型上线的全部信息。
  • ExportModel:使用exportor接口导出的模型文件,包含模型Graph和权重可直接用于上线,但官方已经标记为deprecated推荐使用SavedModel。
  • SavedModel:使用saved_model接口导出的模型文件,包含模型Graph和权限可直接用于上线,TensorFlow和Keras模型推荐使用这种模型格式。
  • FrozenGraph:使用freeze_graph.py对checkpoint和GraphDef进行整合和优化,可以直接部署到Android、iOS等移动设备上。
  • TFLite:基于flatbuf对模型进行优化,可以直接部署到Android、iOS等移动设备上,使用接口和FrozenGraph有些差异。

模型格式

目前建议TensorFlow和Keras模型都导出成SavedModel格式,这样就可以直接使用通用的TensorFlow Serving服务,模型导出即可上线不需要改任何代码。不同的模型导出时只要指定输入和输出的signature即可,其中字符串的key可以任意命名只会在客户端请求时用到,可以参考下面的代码示例。

注意,目前使用tf.py_func()的模型导出后不能直接上线,模型的所有结构建议都用op实现。

TensorFlow模型导出

import os
import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import (
    signature_constants, signature_def_utils, tag_constants, utils)
from tensorflow.python.util import compat

model_path = "model"
model_version = 1
model_signature = signature_def_utils.build_signature_def(
    inputs={
        "keys": utils.build_tensor_info(keys_placeholder),
        "features": utils.build_tensor_info(inference_features)
    },
    outputs={
        "keys": utils.build_tensor_info(keys_identity),
        "prediction": utils.build_tensor_info(inference_op),
        "softmax": utils.build_tensor_info(inference_softmax),
    },
    method_name=signature_constants.PREDICT_METHOD_NAME)
export_path = os.path.join(compat.as_bytes(model_path), compat.as_bytes(str(model_version)))
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
 
builder = saved_model_builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
    sess, [tag_constants.SERVING],
    clear_devices=True,
    signature_def_map={
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        model_signature,
    },
    legacy_init_op=legacy_init_op)
 
builder.save() 

Keras模型导出

import os
import tensorflow as tf
from tensorflow.python.util import compat
 
def export_savedmodel(model):
  model_path = "model"
  model_version = 1
  model_signature = tf.saved_model.signature_def_utils.predict_signature_def(
      inputs={'input': model.input}, outputs={'output': model.output})
  export_path = os.path.join(compat.as_bytes(model_path), compat.as_bytes(str(model_version)))
 
  builder = tf.saved_model.builder.SavedModelBuilder(export_path)
  builder.add_meta_graph_and_variables(
      sess=K.get_session(),
      tags=[tf.saved_model.tag_constants.SERVING],
      clear_devices=True,
      signature_def_map={
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          model_signature
      })
  builder.save()

SavedModel模型结构

使用TensorFlow的API导出SavedModel模型后,可以检查模型的目录结构如下,然后就可以直接使用开源工具来加载服务了。


模型上线

部署在线服务

使用HTTP接口可参考 tobegit3hub/simple_tensorflow_serving

使用gRPC接口可参考 tensorflow/serving

部署离线设备

部署到Android可参考 medium.com/@tobe_ml/all

部署到iOS可参考 zhuanlan.zhihu.com/p/33

发布于 2018-03-12

文章被以下专栏收录