CTC实现——compute ctc loss(2)

本文收录在无痛的机器学习第二季目录

如何控制计算的数量?cont.

好,废话少说我们书接上回。不明真相的小朋友先看这个:

下面我们假设T=4,S/2=3,好玩的地方来了。T比S/2多一个,也就是说我们允许冗余出现了,那么我们可能的形式也就变多了。我们可以增加一个blank,我们也可以在没有label位置原地打一轮酱油。选择更多,欢乐更多。

虽然选择变多,但是着并不意味着我们可以选择任意一种状态转移的方式,至少:

  • 在T=2时,我们至少要转移到第一个结果
  • 在T=3时,我们至少要转移到第二个结果
  • 在T=4时,兄弟我们准备下车了

这其实就是对start的限制。源代码中有这样一句话:

int remain = (S / 2) + repeats - (T - t);

这里我们先忽略repeats,那么remain这个变量其实是在计算label数量和剩余时间的差。如果用这样的语言来表达刚才的那个问题,我们语言就变成这个样子:

  • 当时间还剩4轮时(包括第4轮),我们在哪都无所谓(实际上是从T=1开始计算的)
  • 当时间还剩3轮时(包括第3轮),我们至少要转移到第一个结果(index=1)
  • 当时间还剩2轮时(包括第2轮),我们至少要转移到第二个结果(index=3)
  • 当时间还剩1轮时(包括第1轮),我们至少要转移到第三个结果(index=5)

好了,这里我们看出其中的含义了。我们再啰嗦一下,看看这些变量随T的变化情况:

  • T=1,remain=0,start+=1
  • T=2,remain=1,start+=2
  • T=3,remain=2,start+=2

现在我们已经十分清楚了,当remain>=0时,start都要向前走,限制我们计算前面状态的概率,因为这些概率已经没有意义了。下面的代码也是这样描述的:

if(remain >= 0)
    start += s_inc[remain];

那么这个s_inc是什么东西?它就是我们需要提前准备好的计算量。我们知道经过扩充的label序列中,所有的非空label都处在奇数的index上,而填充的blank都处在偶数的index上(我们是0-based的计算方法,matlab选手请退散……),所以对于上面的问题,当start=0时,下一步我们会从0跳到1,此后我们会从1到3,3到5,跳转的步数都是2,所以基于这个思路,我们就可以把s_inc这个数组生成出来。当然,我们的前提是没有重复。下面我们会说重复的问题的。

我们上面说了这么多,重点把start的变化介绍清楚了。下面我们来看看end。其实end的原理也类似,我们还是用刚才的废话套路来介绍站在end视角的世界:

  • 在T=1时,我们最多能到第一个结果
  • 在T=2时,我们最多能转移到第二个结果
  • 在T=3时,我们最多能转移到第三个结果
  • 在T=4时,我们已经掌握了整个世界……oh yeah

好了,可以看出end的变化形式,每个时刻end都可以+2,直到到达最后一个非blank的label,end变成了+1,然后end就不用动了,等着start动就可以了……(怎么感觉有点污?天哪……)

那么end变化的条件是什么呢?

if(t <= (S / 2) + repeats)
    end += e_inc[t - 1];

我们还是忽略repeats,那么就十分清楚了,如果当前时刻小于等于label数,那么尽管前进,如果大于了,基本上也就到头了,这时候end就不用动了。

好了,前面我们终于说完了简单模式下start和end的移动规律,下面我们来看看带重复模式下的变化方法。

重复,重复

重复会带来什么样的变化呢?说白了如果有重复的label出现,那么两个连续重复的label中间就要至少出现一个blank。换句话说,每出现一个重复,我们的S/2就要加一,于是我们再看一眼这两个计算公式:

int remain = (S / 2) + repeats - (T - t);
if(remain >= 0)
    start += s_inc[remain];
if(t <= (S / 2) + repeats)
    end += e_inc[t - 1];

我们把repeats和S/2归到一起,这时候就能看明白了。

同理,在计算s_inc和e_inc的时候,由于有repeats的存在,它们从过去的+2变成了两个+1。也就是说先从label跳到blank,再跳到下一个label。这样就可以解释s_inc和e_inc的初始化策略了:

int e_counter = 0;
int s_counter = 0;

s_inc[s_counter++] = 1;

int repeats = 0;

for (int i = 1; i < L; ++i) {
    if (labels[i-1] == labels[i]) {
        s_inc[s_counter++] = 1;
        s_inc[s_counter++] = 1;
        e_inc[e_counter++] = 1;
        e_inc[e_counter++] = 1;
        ++repeats;
    }
    else {
        s_inc[s_counter++] = 2;
        e_inc[e_counter++] = 2;
    }
}
e_inc[e_counter++] = 1;

好了,到此我们才算把CTC中compute ctc loss这部分介绍完了。教科书上的一个公式看着简单,落实到代码就似乎充满了trick。希望看懂了这个计算的你大脑没有阵亡。

广告时间

更多精彩尽在《深度学习轻松学:核心算法与视觉实践》

编辑于 2017-11-22

文章被以下专栏收录