TF 中的 indexing 和 slicing

在 numpy 中进行 indexing 相当方便,可以传递一个下标数组,然后根据这些下标取出数组中对应位置的元素(或者切片)。在 TF 中,其实也有对应的操作,主要依赖 tf.gather 和 tf.gather_nd 这两个函数。但是这两个函数的文档(尤其是 tf.gather_nd 的文档)很迷,当你不懂它是怎么执行的时候,看文档也看不懂;当你懂了的时候,就会发现他说的确实都是对的= = 本文主要讲解这两个函数的用法。


PART I: tf.gather


tf.gather 的原型为:

tf.gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

只能沿某一个轴(用 axis 参数指定,默认为第 0 个轴,后文也只对 axis=0 的情况进行讲解)对张量 params 进行索引。因为只沿一个轴进行索引,所以下标是标量。


使用 tf.gather 对单个下标进行索引的示例如下:


1、假设 params 是一维张量

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]

,那么以 1 为下标,取出的就是 tf.gather(params, 1) = params[1] = 1;

2、假设 params 是二维张量

[[ 0  1  2  3  4  5  6  7  8  9 10 11]
 [12 13 14 15 16 17 18 19 20 21 22 23]]

,那么以 1 为下标,取出的就是 tf.gather(params, 1) = params[1] = params[1, :] = [12 13 14 15 16 17 18 19 20 21 22 23];

3、假设 params 是三维张量

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

,那么以 1 为下标,取出的就是 tf.gather(params, 1) = params[1] = params[1, :, :] =

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]


当然,tf.gather 还可以同时对多个下标进行索引,并把结果组合起来:参数 indices 是一个由下标组成的张量(每个下标均为整数标量)。使用 indices 对 params 进行索引,相当于使用 indices 中的每一个元素 index 对 params 进行索引,最后再把结果拼起来。例如:


1、假设 params 是一维张量

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]

,indices 是一维张量 [2, 5],那么 tf.gather(params, [2, 5]) = [params[2], params[5]] = [2, 5];

2、假设 params 是一维张量

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]

,indices 是二维张量 [[2, 5], [0, 6]],那么结果相当于 tf.gather(params, [[2, 5], [0, 6]]) = [[params[2], params[5]], [params[0], params[6]]],即 [[2, 5], [0, 6]];

3、假设 params 是二维张量

[[ 0  1  2  3  4  5  6  7  8  9 10 11]
 [12 13 14 15 16 17 18 19 20 21 22 23]]

,indices 是一维张量 [1, 0],那么结果相当于 tf.gather(params, [1, 0]) = [params[1], params[0]] =

[[12 13 14 15 16 17 18 19 20 21 22 23]
[ 0  1  2  3  4  5  6  7  8  9 10 11]]

注意这里结果是二维,这是因为原先的 indices 形如 (2, ),其中的每个位置 i 经过 params[i] 的索引以后变成了一个长为 12 的一维张量,因此结果就变成了形如 (2, 12) 的二维张量;

4、假设 params 是三维张量

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]


 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

,indices 是一维张量 [1, 0],那么结果相当于 tf.gather(params, [1, 0]) = [params[1], params[0]] =

[[[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]

 [[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]]

同样,indices 中的每个元素经过索引都变成了 3*4 的张量,所以最终结果就变成了 2*3*4 的张量。

5、假设 params 是三维张量

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]


 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

,indices 是二维张量 [[1, 0], [0, 1]],那么结果相当于 tf.gather(params, [[1, 0], [0, 1]]) = [[params[1], params[0]], [params[1], params[0]]] =

[[[[12 13 14 15]
   [16 17 18 19]
   [20 21 22 23]]

  [[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]]]


 [[[ 0  1  2  3]
   [ 4  5  6  7]
   [ 8  9 10 11]]

  [[12 13 14 15]
   [16 17 18 19]
   [20 21 22 23]]]]

注:indices 本身是 2*2 的,其中每个元素经过索引都变成了 3*4 的张量,所以最终结果就变成了 2*2*3*4 的张量。



PART II: tf.gather_nd


tf.gather_nd 的原型为:

tf.gather_nd(
    params,
    indices,
    name=None
)

该函数实际上是 tf.gather 的高维推广,可以支持按多个坐标轴同时进行索引。实践中会取前几个坐标轴进行索引;如果需要跳着取某几个轴,可以和 tf.transpose 相结合,把需要索引的轴转到前几个维度上,最后再转回去。


以三维张量 params

[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]


 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

为例,对单个下标进行索引的操作示例如下(tf.gather_nd 可以按任意 k 个坐标轴进行索引。这里的 params 有三个维度,所以可以按 1 个、2 个、3 个坐标轴进行索引,索引结果分别为 2 维、1 维、0 维张量):


1、以一维下标 [1] (注意和 tf.gather 不同,这里的 [1] 有一层括号)进行索引,结果为 tf.gather_nd(params, [1]) = params[1] = params[1, :, :] =

[[12 13 14 15]
 [16 17 18 19]
 [20 21 22 23]]

2、以二维下标 [1, 2] 进行索引,可以索引前两个坐标轴,结果为 tf.gather_nd(params, [1, 2]) = params[1, 2] = params[1, 2, :] = [20 21 22 23]

3、以三维下标 [1, 2, 3] 进行索引,可以索引前三个坐标轴,结果为 tf.gather_nd(params, [1, 2, 3]) = params[1, 2, 3] = 23。


注意 tf.gather_nd 和 tf.gather 的区别:

  • tf.gather 在索引时直接传一个整数作为下标即可,而 tf.gather_nd 则要用一层括号把整数括起来作为下标(因为他支持按前若干个不定数量的坐标轴进行索引,括号用于告诉 tf.gather_nd 到底索引几个坐标轴)

基于上述区别,假如 indices = [1, 0],那么 tf.gather(params, indices) 不等于 tf.gather_nd(params, indices)。这是因为,前者等于 [params[1], params[0]],即它会把 1 和 0 当成两个下标,分别取出 params[1] 和 params[0],最后加上一层括号返回。结果为:

[[[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]

 [[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]]

而后者则会把 [1, 0] 看成一个整体,取出 params[1, 0] = params[1, 0, :] = [12 13 14 15]。


假如要用 tf.gather_nd 得到和 tf.gather(params [1, 0]) 相同的效果,那么应该使用 tf.gather_nd(params, [[1], [0]])。原因后面解释。


tf.gather_nd 自然也可以对多个下标同时进行索引。同样,此时需要把 indices 参数看成下标的数组,只不过下标已经不再是单个整数,而是 indices 的最后一维的向量。例如,假设 indices = [[1, 0], [0, 1]],那么在 tf.gather 中,indices 被视为一个2*2 的下标张量,每个下标是一个整数;而在 tf.gather_nd 中,indices 被视为一个长为 2 的下标张量 [ind1, ind2],其中每个下标是长为 2 的数组,即 ind1 = [1, 0],ind2 = [0, 1]。


对于一般的情况,假设 indices 是形如 n_1\times n_2\times \cdots\times n_{s-1}\times n_s 的张量,可以将其视作 n_1\times n_2\times \cdots\times n_{s-1} 个下标,其中每个下标为 n_s 维。假设 params 的 rank(即坐标轴个数)为 r,那么显然应该有 n_s <= r(可以按照 r 个轴中的 1 个、2 个……直到全部 r 个轴进行索引);假设 params 张量的形状为 params.shape,那么索引结果就是形如 params.shape[n_s: ] 的张量(可以对照上面几个对单个下标进行索引的例子来观察)。然后将全部 n_1\times n_2\times \cdots\times n_{s-1} 个下标的索引结果拼起来,最终结果的形状就是 n_1\times n_2\times \cdots\times n_{s-1}\times params.shape[n_s: ]。


以 tf.gather_nd(params, [[1, 0], [0, 1]]) 为例,具体的计算流程相当于分别检索 tf.gather_nd(params, [1, 0])(等于 params[1, 0] = params[1, 0, :] = [12 13 14 15])和 tf.gather_nd(params, [0, 1])(等于 params[0, 1] = params[0, 1, :] = [4 5 6 7]);最后再加一层括号括起来。即:

[[12 13 14 15]
 [ 4  5  6  7]]

明白了上面的例子,tf.gather_nd(params, [[1], [0]]) 和 tf.gather(params, [1, 0]) 的等价性就比较显然了。


其实这篇文章说的就是 tf.gather 和 tf.gather_nd 的文档中的内容,也没有多什么东西。只是直接看 tf.gather_nd 的文档太迷了,希望这篇文章能讲清楚 tf.gather_nd 的用法。如果读者现在再回过头去看 tf.gather_nd 的文档,就会发现,文档说的确实都是对的(扶额= =

编辑于 2018-08-17

文章被以下专栏收录