tensorflow模型线上部署

这次写一个比较短的工程部署总结,说是tensorflow线上部署,并没有使用目前推荐的tensorflow serving。一则是该功能只是整个项目中的一个很小的功能点,如果单独为该功能部署tensorflow serving,成本和时间上会比较超标;二则由于公司内部网络环境限制,无法完整顺利获取tensorflow serving需要的依赖包(在没有网络环境下去装tensorflow serving环境的同学都懂的);三则,项目并没有对模型在线学习的需求,只需要模型离线训练后,部署到线上即可。基于上述三点条件,选择了使用tensorflow python开发并训练模型,最后将模型进行序列化。然后在java端(项目主开发语言)调用模型进行在线预测。

前置条件

tensorflow 1.3

所需的tensorflow java jar包:libtensorflow-1.4.1.jar

所需的tensorflow jni 库文件:libtensorflow_jni.so libtensorflow_framework.so

模型序列化

一开始,由于未考虑模型如何移植到java平台的原因,模型训练和保存的时候,还是按照传统的方式,即通过tf.train.Saver()方法得到saver对象,然后调用saver.save方法保存模型,得到模型的checkpoint,meta,data,index四个文件。

上述文件中保存了模型的中间参数,模型当前训练的状态等信息,在某种意义上是动态的,只能通过tensorflow本身的python接口来调用该保存好的模型,并不能跨平台直接使用。其实模型的本质还是一堆权重数据和具体的权重计算流程,因此需要某种机制能固定住模型的权重数据和计算流程,即freeze模型。

saver_predict = tf.train.import_meta_graph(model_config.ckpt_path + model_name)
with tf.Session() as sess:
  if os.path.exists(model_config.ckpt_path):
             print("Restoring Variables from Checkpoint")
             saver_predict.restore(sess, tf.train.latest_checkpoint(model_config.ckpt_path))
  else:
             print("Can't find the checkpoint.going to stop")
             exit()
  output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]
  frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names)
  tf.train.write_graph(frozen_graph_def, model_config.ckpt_path, target_model_path, as_text=False)

上述代码的作用是对使用import_meta_graph读取训练好的模型,然后获取当前计算图中的所有节点,并将这些节点中的所有权重参数转化为常量,最后将这些常量保存到一个pb文件中,pb文件即protobuf,是 Google 推出的一种二进制数据交换格式。能够用于跨平台间的数据交换。上述代码实际使用时,会有问题,在java侧使用java的api调用模型时,会出现如下错误:

Invalid argument: Input 0 of node XXXXXXXXXXX/BatchNorm/cond/AssignMovingAvg_1/Switch was passed float from XXXXXXXXXXX/BatchNorm/moving_variance:0 incompatible with expected float_ref.

经过一番google,找到了这个github的issue,貌似并没有得到很好的解决,推测问题原因是在freeze模型的权重参数时,对一些tensor的data type处理有问题。问题链接如下,有兴趣的同学可以看看,我最后换了另外一种方式来做。

Error loading a frozen graph ( float incompatible with float_ref ) · Issue #161 · davidsandberg/facenetgithub.com图标

最终,我使用的是另外一种方法,即在模型训练完时,直接对模型的权重参数序列化,保存为pb文件。代码如下:

# 将模型序列化保存
builder = tf.saved_model.builder.SavedModelBuilder(model_config.pb_path.format(epoch))
builder.add_meta_graph_and_variables(sess, ["XXX"])
builder.save()

调用SavedModelBuilder得到builder对象,然后将session中的所有图结构和权重参数存入到builder中,最后保存为pb文件。

Java调用模型在线预测

这里主要使用的是tensorflow的Java版本api,说是Java版本,其实功能很局限,并没有模型训练方面的功能,所幸它有读取模型和数据,然后在线预测的功能,因此可以适用于当前场景。在部署时,有几点需要注意一下:

1.最基础的事情,当然是记得将tensorflow jar包加入到build-path中。

2.在linux上进行部署时,需要将两个so文件部署到项目工程的java build path中,因为我们的工程中的path包含/usr/lib/,因此可以将两个so文件放到这个路径下。两个so文件主要是用于tensorflow 上层的api与底层操作系统native library进行通信的接口。

下面主要介绍一下如何使用java来调用模型预测。

首先列出用到 两个主要的操作对象:

SavedModelBundle tensorflowModelBundle
Session tensorflowSession

SavedModelBundle为java侧与pb文件接口的对象,能够读取pb文件。而Session对应的是tensorflow中的会话对象,java中,tensorflow的预测操作也是需要在一个会话中进行的。

tensorflowModelBundle = SavedModelBundle.load(tensorflowModelPath, "XXXX");
tensorflowSession = tensorflowModelBundle.session();

然后就是构造输入模型的数据了。同python中的情况类似,java侧的模型接收的数据类型必须为tensor类型,因此需要将数据转化为tensor。因此要用到Tensor对象的create方法来生成Tensor,假设当前我们处理好后的数值型输入数据为一个二维数组testvec:

Tensor input = Tensor.create(testvec);

当然如果有其他输入的话,也要都转化为Tensor,简单说就是原来模型中feed_dict中的所有输入都要转化为Tensor对象。

然后就是调用session,输入需要的数据,然后调用具体的计算节点输出结果:

Tensor output = tensorflowSession.runner().feed("input",input).feed(XX,XX).fetch("computation node_name").run().get(0);

这行代码有几个注意点:

1、feed方法返回的仍然是Runner对象,这个机制使得我们可以链式调用feed方法,将所有需要喂入模型的数据装载。

2、Runner对象的fetch方法是定位到具体的计算图中的计算节点(tensor),这个与python中调用模型预测的方法类似,需要在构造计算图的时候,对模型输出样本预测概率(或者logits)的tensor指定名称。

3、最后的get()方法则是获取返回的结果,这里我输入了参数0,是因为run()方法默认返回的是一个List<Tensor>,因为有可能有的需求会调用多个计算节点,因此会返回多个tensor,但是此时我只需要得到一个tensor结果,因此获取List中的第一个元素。

上述方法返回的是一个Tensor对象,为了输出结果,需要将它转化为原始的二维数据格式:

float[][] resultValues = (float[][]) out.copyTo(new float[1][1]);

调用Tensor的copyTo方法,能够将Tensor转化为指定数据格式的数组。之所以是二维数组,是因为我们的输入数据是二维数组,虽然一次一般是预测一个样本,但为了开发的普适性,统一处理为二维数组,数组存储的就是该样本的预测概率。

最后有一点需要注意一下,在使用完模型后,需要将所有创建的Tensor关闭,销毁资源,当然这个是开发的一个好习惯,能够避免资源的泄露和低效利用。

out.close();
input.close();

发布于 2019-01-28

文章被以下专栏收录