AI见闻
首发于AI见闻

【NLP】BERT中文实战踩坑

终于用上了bert,踩了一些坑,和大家分享一下。


我主要参考了奇点机智的文章,用bert做了两个中文任务:文本分类和相似度计算。这两个任务都是直接用封装好的run_classifer,py,另外两个没有仔细看,用到了再补充。

1. DataProcessor

Step1:写好自己的processor,照着例子写就可以,一定要shuffle!!!

Step2:加到main函数的processors字典里

2. Early Stopping

Step1:建一个hook

early_stopping_hook = tf.contrib.estimator.stop_if_no_decrease_hook(
            estimator=estimator,
            metric_name='eval_loss',
            max_steps_without_decrease=FLAGS.max_steps_without_decrease,
            eval_dir=None,
            min_steps=0,
            run_every_secs=None,
            run_every_steps=FLAGS.save_checkpoints_steps)

Step2:加到estimator.train里

estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=[early_stopping_hook])

3. Train and Evaluate

需要用tensorboard查看训练曲线的话比较好

Step1:创建train和eval的spec,这里需要把early stopping的hook加到trainSpec

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=num_train_steps,
                                                hooks=[early_stopping_hook])
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, throttle_secs=60)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

4. Batch size

默认Eval和Predict的batch size都很小,记得改一下

5. 模型保存

存储为SavedModel,参考 bert_serving

# 加到开头
flags.DEFINE_bool("do_export", False, "Whether to export the model.")
flags.DEFINE_string("export_dir", None, "The dir where the exported model will be written.")

# 加到结尾
def serving_input_fn():
    label_ids = tf.placeholder(tf.int32, [None], name='label_ids')
    input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
    input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
    segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'label_ids': label_ids,
        'input_ids': input_ids,
        'input_mask': input_mask,
        'segment_ids': segment_ids,
    })()
    return input_fn

if FLAGS.do_export:
        estimator._export_to_tpu = False
        estimator.export_savedmodel(FLAGS.export_dir, serving_input_fn)


6. 模型速度优化

  • 需要移动设备部署的用TensorFlow Lite的post_training_quantization函数
  • 有GPU的用NAVIDIA的TensorRT
  • 其他的Tensorflow自带optimize_for_inference函数,参考 bert-as-service

7.BERT服务

试了一下bert-as-service,如果不要求后端语言的话用这个性能就很好了。如果要用java,可以参考bert-as-service/server/graph.py的代码把模型保存为pb文件,用java调用。


hi all!最近终于有了自己的公众号,叫NLPCAB,本来想叫LAB,但觉得和一个人能撑起实验室我就上天了,所以取了谐音CAB,有些可爱并且意味深长?之后会努力和Andy发干货,也希望各位同学投稿学习笔记~

编辑于 2019-10-22

文章被以下专栏收录

    分享一些炼丹经验、在读的paper,主要是NLP相关,欢迎喜欢分享的同行投稿