CTC实现——compute ctc loss(1)

本文收录在无痛的机器学习第二季目录

上一次我们介绍了关于CTC的一些基本问题,下面我们还是落到实处,来介绍一个经典的CTC实现代码——来自百度的warp-ctc。这次百度没有作恶……

CTC:前向计算例子

这里我们直接使用warp-ctc中的变量进行分析。我们定义T为RNN输出的结果的维数,这个问题的最终输出维度为alphabet_size。而ground_truth的维数为L。也就是说,RNN输出的结果为alphabet_size*T的结果,我们要将这个结果和1*L这个向量进行对比,求出最终的Loss。

我们要一步一步地揭开这个算法的细节……当然这个算法的实现代码有点晦涩……

我们的第一步要顺着test_cpu.cpp的路线来分析代码。第一步我们就是要解析small_test()中的内容。也就是做前向计算,计算对于RNN结果来说,对应最终的ground_truth——t的label的概率。

这个计算过程可以用动态规划的算法求解。我们可以用一个变量来表示动态规划的中间过程,它就是:

\alpha^T_i:表示在RNN计算的时间T时刻,这一时刻对应的ground_truth的label为第i个下标的值t[i]的概率。

这样的表示有点抽象,我们用一个实际的例子来讲解:

RNN结果:[R_1,R_2,R_3,R_4],这里的每一个变量都对应一个列向量。

ground_truth:[g_1,g_2,g_3]

那么\alpha^2_1表示R_2的结果对应着g_1的概率,当然与此同时,前面的结果也都合理地对应完成。

从上面的结果我们可以看出,如果R_2的结果对应着g_1,那么R_1的结果也必然对应着g_1。所以前面的结果是确定的。然而对于其他的一些情况来说,我们的转换存在着一定的不确定性。

CTC:前向计算具体过程

我们还是按照上面的例子进行计算,我们把刚才的例子搬过来:

RNN结果:[R_1,R_2,R_3,R_4],这里的每一个变量都对应一个列向量。

ground_truth:[g_1,g_2,g_3]

alphabet:[g_0(blank),g_1,g_2,g_3]

按照上面介绍的计算方法,第一步我们先做ground_truth的状态扩展,于是我们就把长度从3扩展到了7,现在的ground_truth变成了:

[blank,g_1,blank,g_2,blank,g_3,blank]

我们的RNN结果长度为4,也就是说我们会从上面的7个ground_truth状态中进行转移,并最终转移到最终状态。理论上利用动态规划的算法,我们需要计算4*7=28个中间结果。好了,下面我们用P^T_i表示RNN的第T时刻状态为ground_truth中是第i个位置的概率。

那么我们就开始计算了:

T=1时,我们只能选择g_1和blank,所以这一轮我们终结状态只可能落在0和1上。所以第一轮变成了:

[P^1_0,P^1_1,0,0,0,0,0]

T=2时,我们可以继续选择g_1,我们同时也可以选择g_2,还可以选择g_1g_2之间的blank,所以我们可以进一步关注这三个位置的概率,于是我们将其他的位置的概率设为0。[0,(P^1_0 +P^1_1)P^2_1,P^1_1P^2_2,P^1_1P^2_3,0,0,0]

T=3时,留给我们的时间已经不多了,我们还剩2步,要走完整个旅程,我们只能选择g_2g_3以及它们之间的空格。于是乎我们关心的位置又发生了变化:

[0,0,0, (P^1_1P^2_2+P^1_1P^2_3)P^3_3, P^1_1P^2_3P^3_4, P^1_1P^2_3P^3_5, 0]

是不是有点看晕了?没关系,因为还剩最后一步了。下面是最后一步,因为最后一步我们必须要到g_3以及它后面的空格了,所以我们的概率最终计算也就变成了:

[0,0,0, 0,0, ((P^1_1P^2_2+P^1_1P^2_3)P^3_3+P^1_1P^2_2P^3_4+P^1_1P^2_2P^3_3)P^4_5, P^1_1P^2_3P^3_5P^4_6]

好吧,最终的结果我们求出来了,实际上这就是通过时间的推移不断迭代求解出来的。关于迭代求解的公式这里就不再赘述了。我们直接来看一张图:

(注:T=2时的第一个红色方框应该有一条线连接T=3的第一个红色方框,感谢@WXB506 指出)

于是乎我们从这个计算过程中发现一些问题:

首先是一个相对简单的问题,我们看到在计算过程中我们发现了大量的连乘。由于每一个数字都是浮点数,那么这样连乘下去,最终数字有可能非常小而导致underflow。所以我们要将这个计算过程转到对数域上。这样我们就将其中的乘法转变成了加法。但是原本就是加法的计算呢?比方说我们现在计算了loga和logb,我们如何计算log(a+b)呢,这里老司机给出了解决方案,我们假设两个数中a>b,那么有

log(a+b)=log(a(1+\frac{b}{a}))=loga+log(1+\frac{b}{a})

=loga+log(1+exp(log(\frac{b}{a})))=loga+log(1+exp(logb - loga))

这样我们就利用了loga和logb计算出了log(a+b)来。

另外一个问题就是,我们发现在刚才的计算过程当中,对于每一个时间段,我们实际上并不需要计算每一个ground-truth位置的概率信息,实际上只要计算满足某个条件的某一部分就可以了。所以我们有没有希望在计算前就规划好这条路经,以保证我们只计算最相关的那些值呢?

如何控制计算的数量?

不得不说,这一部分warp-ctc写得实在有点晦涩,当然也可能是我在这方面的理解比较渣。我们这里主要关注两个部分——一个是数据的准备,一个是最终的数据的使用。

在介绍数据准备之前,我们先简单说一下这部分计算的大概思路。我们用两个变量start和end表示我们需要计算的状态的起止点,在每一个时间点,我们要更新start和end这两个变量。然后我们更新start和end之间的概率信息。这里我们先要考虑一个问题,start和end的更新有什么规律?

为了简化思考,我们先假设ground_truth中没有重复的label,我们的大脑瞬间得到了解放。好了,下面我们就要给出代码中的两个变量——

T:表示RNN结果中的维度

S/2:ground_truth的维度(S表示了扩展blank之后的维度)

基本上具备一点常识,我们就可以知道T>=S/2。什么?你觉得有可能出现T<S/2的情况?兄弟,这种见鬼的事情如果发生,你难道要我们把RNN的结果拆开给你用?臣妾不太能做得到啊……

好了,既然接受了上面的事实,那么我们就来举几个例子看看:

我们假设T=3,S/2=3,那么说白了,它们之间的对应关系是一一对应,说白了这就和blank位置没啥关系了。在T=1时,我们要转移到第一个结果,T=2,我们要转移到第二个结果……

那么我们还有别的情况么?下回更精彩。

广告时间

更多精彩尽在《深度学习轻松学:核心算法与视觉实践》

编辑于 2017-11-22

文章被以下专栏收录