A Detailed Walkthrough of the Google WaveNet Source Code


WaveNet is a deep-learning model for generating raw audio, released by Google DeepMind. It models the raw waveform directly and performs very well on text-to-speech and audio generation tasks (for background, see: How does Google's WaveNet use deep learning to generate sound?). This post walks through a TensorFlow implementation of WaveNet, the one published on GitHub by ibab (link: ibab's WaveNet repository).

The post is organized as follows: 1. an overview of the WaveNet architecture; 2. a detailed walkthrough of the source code; 3. a summary.

1. WaveNet architecture overview

WaveNet combines dilated and causal convolutions so that the receptive field grows exponentially with network depth, which is what makes modeling raw audio directly feasible. (For details, see: How does Google's WaveNet use deep learning to generate sound?)
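
To get a feel for how fast the receptive field grows, here is a rough back-of-the-envelope calculation. The filter width and dilation schedule below are commonly used defaults and are assumed here for illustration, not read from the repository's config file:

# Rough illustration of receptive-field growth for a stack of dilated causal
# convolutions (assumed config: filter_width = 2, dilations 1..512 repeated 5 times).
filter_width = 2
dilations = [2 ** i for i in range(10)] * 5

# Each dilated layer extends the receptive field by (filter_width - 1) * dilation samples.
receptive_field = (filter_width - 1) * sum(dilations) + 1
print(receptive_field)  # 5116 samples, roughly 0.32 seconds of 16 kHz audio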

2. Source code walkthrough


2.1 Overview

Enough preamble; let's go straight to the code. The folder downloaded from GitHub looks like the figure below:


The key pieces are train.py, generate.py, and the wavenet folder: train.py is the training script, generate.py is the generation script, and the wavenet folder contains the model, the audio reader, and assorted utility classes and functions, as shown below:

2.2 train.py walkthrough

Let's officially begin our WaveNet journey with train.py. It defines a set of default parameters, the model save()/load() helpers, and the main() function.

2.2.1 Default parameters

BATCH_SIZE = 1  # batch size
DATA_DIRECTORY = './VCTK-Corpus'  # path to the training data
LOGDIR_ROOT = './logdir'  # root directory for logs and checkpoints
CHECKPOINT_EVERY = 50  # save a checkpoint every this many steps
NUM_STEPS = 4000  # total number of training steps
LEARNING_RATE = 0.02  # learning rate
WAVENET_PARAMS = './wavenet_params.json'  # model hyperparameter file
STARTED_DATESTRING = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())  # timestamp used to name the log directory
SAMPLE_SIZE = 100000  # number of audio samples per training piece
L2_REGULARIZATION_STRENGTH = 0  # L2 regularization strength (0 disables it)
SILENCE_THRESHOLD = 0.3  # energy threshold for trimming leading/trailing silence
EPSILON = 0.001  # silence thresholds at or below this disable trimming
ADAM_OPTIMIZER = 'adam'  # name of the Adam optimizer option
SGD_OPTIMIZER = 'sgd'  # name of the SGD (momentum) optimizer option
SGD_MOMENTUM = 0.9  # momentum for the SGD optimizer
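
These constants are the default values for the command-line arguments parsed by get_arguments(). A shortened sketch of what such an argparse-based parser looks like (only a few options are shown; the real get_arguments() in train.py defines more):

import argparse

def get_arguments():
    # Shortened sketch of the argument parser built on the defaults above.
    parser = argparse.ArgumentParser(description='WaveNet training script')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE)
    parser.add_argument('--data_dir', type=str, default=DATA_DIRECTORY)
    parser.add_argument('--num_steps', type=int, default=NUM_STEPS)
    parser.add_argument('--learning_rate', type=float, default=LEARNING_RATE)
    parser.add_argument('--wavenet_params', type=str, default=WAVENET_PARAMS)
    parser.add_argument('--optimizer', type=str, default=ADAM_OPTIMIZER,
                        choices=[ADAM_OPTIMIZER, SGD_OPTIMIZER])
    parser.add_argument('--sgd_momentum', type=float, default=SGD_MOMENTUM)
    return parser.parse_args()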

2.2.2 Saving and loading the model

This part is straightforward; the key calls are TensorFlow's checkpoint save and restore functions.

saver.save(sess, checkpoint_path, global_step=step)  # save the model parameters

saver.restore(sess, ckpt.model_checkpoint_path)  # restore the model parameters
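
For reference, the save() and load() helpers that wrap these two calls are roughly structured as follows (a condensed sketch; logging and edge cases are trimmed):

import os
import tensorflow as tf

def save(saver, sess, logdir, step):
    # Write a checkpoint named model.ckpt-<step> under logdir.
    checkpoint_path = os.path.join(logdir, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=step)

def load(saver, sess, logdir):
    # Restore the latest checkpoint in logdir (if any) and return its
    # global step, or None when no checkpoint is found.
    ckpt = tf.train.get_checkpoint_state(logdir)
    if ckpt and ckpt.model_checkpoint_path:
        global_step = int(ckpt.model_checkpoint_path.split('-')[-1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        return global_step
    return None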

2.2.3 The main() function

main() contains the core of the training procedure: 1. load the WaveNet hyperparameters; 2. create a TensorFlow Coordinator; 3. build the input pipeline from the VCTK corpus; 4. build the WaveNet model; 5. train the model and save checkpoints.

def main():
    # Parse the command-line arguments
    args = get_arguments()
    ...
    # (some less important code omitted)

    # Load the WaveNet hyperparameters
    with open(args.wavenet_params, 'r') as f:
        wavenet_params = json.load(f)

    # Create the coordinator.
    coord = tf.train.Coordinator()

    # Build the input pipeline from the VCTK corpus.
    with tf.name_scope('create_inputs'):
        # Allow silence trimming to be skipped by specifying a threshold near
        # zero.
        silence_threshold = args.silence_threshold if args.silence_threshold > \
                                                      EPSILON else None
        # Uses the AudioReader class from audio_reader.py (explained in detail below)
        reader = AudioReader(
            args.data_dir,
            coord,
            sample_rate=wavenet_params['sample_rate'],
            sample_size=args.sample_size,
            silence_threshold=silence_threshold)
        audio_batch = reader.dequeue(args.batch_size)  # dequeue a batch of audio

    # Build the network using the WaveNetModel class from model.py (explained in detail below).
    net = WaveNetModel(
        batch_size=args.batch_size,
        dilations=wavenet_params["dilations"],
        filter_width=wavenet_params["filter_width"],
        residual_channels=wavenet_params["residual_channels"],
        dilation_channels=wavenet_params["dilation_channels"],
        skip_channels=wavenet_params["skip_channels"],
        quantization_channels=wavenet_params["quantization_channels"],
        use_biases=wavenet_params["use_biases"],
        scalar_input=wavenet_params["scalar_input"],
        initial_filter_width=wavenet_params["initial_filter_width"])
    # Whether to use L2 regularization
    if args.l2_regularization_strength == 0:
        args.l2_regularization_strength = None
    # Compute the loss
    loss = net.loss(audio_batch, args.l2_regularization_strength)
    # Choose between the Adam and SGD (momentum) optimizers
    if args.optimizer == ADAM_OPTIMIZER:
        optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    elif args.optimizer == SGD_OPTIMIZER:
        optimizer = tf.train.MomentumOptimizer(learning_rate=args.learning_rate,
                                               momentum=args.sgd_momentum)
    else:
        # This shouldn't happen, given the choices specified in argument
        # specification.
        raise RuntimeError('Invalid optimizer option.')
    trainable = tf.trainable_variables()
    optim = optimizer.minimize(loss, var_list=trainable)  # minimize the loss over the trainable variables

    # Write summaries for TensorBoard.
    writer = tf.train.SummaryWriter(logdir)
    writer.add_graph(tf.get_default_graph())
    run_metadata = tf.RunMetadata()
    summaries = tf.merge_all_summaries()

    # Start the session
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))
    init = tf.initialize_all_variables()
    sess.run(init)

    # Saver for storing checkpoints of the model.
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    try:
        saved_global_step = load(saver, sess, restore_from)
        if is_overwritten_training or saved_global_step is None:
            # The first training step will be saved_global_step + 1,
            # therefore we put -1 here for new or overwritten trainings.
            saved_global_step = -1

    except:
        print("Something went wrong while restoring checkpoint. "
              "We will terminate training to avoid accidentally overwriting "
              "the previous model.")
        raise

    # This uses TensorFlow's threading and queue mechanisms
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    reader.start_threads(sess)

    ...
    ...
    # (the training loop is omitted here; see the sketch after this block)
    finally:
        if step > last_saved_step:
            save(saver, sess, logdir, step)
        coord.request_stop()
        coord.join(threads)
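
The code omitted above is the training loop itself. A minimal sketch of such a loop, with the variable and argument names assumed from the surrounding snippet and the defaults listed in 2.2.1, and with timing/trace details left out:

    last_saved_step = saved_global_step
    try:
        for step in range(saved_global_step + 1, args.num_steps):
            # One optimization step; also fetch the merged summaries and the loss value.
            summary, loss_value, _ = sess.run([summaries, loss, optim])
            writer.add_summary(summary, step)
            print('step {:d} - loss = {:.3f}'.format(step, loss_value))

            if step % args.checkpoint_every == 0:
                save(saver, sess, logdir, step)
                last_saved_step = step
    except KeyboardInterrupt:
        # Allow training to be interrupted with Ctrl+C; the finally block
        # shown above still saves a checkpoint and stops the reader threads.
        print()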

main() relies on classes from audio_reader.py and model.py, so let's dig into those next.


2.3 audio_reader.py walkthrough

audio_reader.py contains four functions (find_files(), load_generic_audio(), load_vctk_audio(), trim_silence()) and one class, AudioReader.

Of the four functions, trim_silence() is worth a closer look: it removes the silent segments at the beginning and end of each audio clip.


# Remove the silence at the beginning and end of an audio clip.
def trim_silence(audio, threshold):
    '''Removes silence at the beginning and end of a sample.'''
    energy = librosa.feature.rmse(audio)  # frame-wise energy (RMS) of the audio
    frames = np.nonzero(energy > threshold)  # frames whose energy exceeds the threshold
    indices = librosa.core.frames_to_samples(frames)[1]

    # Note: indices can be an empty array, if the whole audio was silence.
    return audio[indices[0]:indices[-1]] if indices.size else audio[0:0]
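
A quick way to see what trim_silence() does is to run it on a single file (a small sketch; the file path is just a placeholder and librosa must be installed):

import librosa

# Placeholder path to one VCTK utterance.
audio, _ = librosa.load('VCTK-Corpus/wav48/p225/p225_001.wav', sr=16000, mono=True)
trimmed = trim_silence(audio, threshold=0.3)
print(len(audio), '->', len(trimmed))  # leading and trailing silence removed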

Now let's look at the AudioReader class. It has four methods, and its job is to preprocess the audio data and feed it into a TensorFlow queue.

class AudioReader(object):
    '''Generic background audio reader that preprocesses audio files
    and enqueues them into a TensorFlow queue.'''

    def __init__(self,
                 audio_dir,
                 coord,
                 sample_rate,
                 sample_size=None,
                 silence_threshold=None,
                 queue_size=256):
        self.audio_dir = audio_dir  # path to the training data
        self.sample_rate = sample_rate  # sample rate
        self.coord = coord
        self.sample_size = sample_size  # number of samples per training piece
        self.silence_threshold = silence_threshold  # energy threshold below which audio is trimmed as silence
        self.threads = []  # reader threads
        self.sample_placeholder = tf.placeholder(dtype=tf.float32, shape=None)
        
        # Create the queue
        self.queue = tf.PaddingFIFOQueue(queue_size,
                                         ['float32'],
                                         shapes=[(None, 1)])
        # Enqueue op
        self.enqueue = self.queue.enqueue([self.sample_placeholder])

    # Dequeue a batch of elements from the queue
    def dequeue(self, num_elements):
        output = self.queue.dequeue_many(num_elements)
        return output

    # Main loop of the reader thread
    def thread_main(self, sess):
        buffer_ = np.array([])  # buffer for accumulating samples
        stop = False
        # Go through the dataset multiple times
        while not stop:
            iterator = load_generic_audio(self.audio_dir, self.sample_rate)  # yields the audio files one by one
            for audio, filename in iterator:
                if self.coord.should_stop():
                    stop = True
                    break
                if self.silence_threshold is not None:
                    # Remove silence
                    audio = trim_silence(audio[:, 0], self.silence_threshold)  # trim leading/trailing silence
                    if audio.size == 0:
                        print("Warning: {} was ignored as it contains only "
                              "silence. Consider decreasing trim_silence "
                              "threshold, or adjust volume of the audio."
                              .format(filename))

                if self.sample_size:
                    # Cut samples into fixed size pieces
                    buffer_ = np.append(buffer_, audio)
                    while len(buffer_) > self.sample_size:
                        piece = np.reshape(buffer_[:self.sample_size], [-1, 1])
                        sess.run(self.enqueue,
                                 feed_dict={self.sample_placeholder: piece})
                        buffer_ = buffer_[self.sample_size:]
                else:
                    sess.run(self.enqueue,
                             feed_dict={self.sample_placeholder: audio})  # push the whole clip into the queue

    # Start the reader threads
    def start_threads(self, sess, n_threads=1):
        for _ in range(n_threads):
            thread = threading.Thread(target=self.thread_main, args=(sess,))
            thread.daemon = True  # Thread will close when parent quits.
            thread.start()
            self.threads.append(thread)
        return self.threads

The most reusable part here is the use of TensorFlow threads and queues (see the official TensorFlow documentation). Using a queue involves three steps: 1. build the queue; 2. define the enqueue and dequeue ops; 3. run those ops to push data in and pull batches out. A toy example is sketched below.

(The accompanying queue-pipeline figure, omitted here, is from the official TensorFlow documentation.)
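
Here is a self-contained toy example of those three steps, using the same PaddingFIFOQueue API as AudioReader (old-style TensorFlow session API, matching the version used throughout this post):

import numpy as np
import tensorflow as tf

# 1. Build the queue and define its enqueue/dequeue ops.
sample_ph = tf.placeholder(tf.float32, shape=None)
queue = tf.PaddingFIFOQueue(8, ['float32'], shapes=[(None, 1)])
enqueue_op = queue.enqueue([sample_ph])
batch = queue.dequeue_many(2)

with tf.Session() as sess:
    # 2. Feed two variable-length pieces into the queue.
    for length in (3, 5):
        piece = np.ones((length, 1), dtype=np.float32)
        sess.run(enqueue_op, feed_dict={sample_ph: piece})
    # 3. Dequeue a batch: the shorter piece is zero-padded, so the shape is (2, 5, 1).
    print(sess.run(batch).shape)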

Using threads with a Coordinator involves: 1. create a Coordinator object; 2. create the threads; 3. start the threads and join them through the Coordinator.

# Thread body: loop until the Coordinator signals that it should stop.
# If some condition becomes true, ask the Coordinator to stop the other threads.
def MyLoop(coord):
  while not coord.should_stop():
    ...do something...
    if ...some condition...:
      coord.request_stop()

# Create a Coordinator
coord = tf.train.Coordinator()

# Create 10 threads that run 'MyLoop()'
threads = [threading.Thread(target=MyLoop, args=(coord,)) for i in xrange(10)]

# Start the threads and wait for all of them to stop.
for t in threads: t.start()
coord.join(threads)

(The code above is taken from the official TensorFlow documentation.)

2.4 model.py walkthrough

This is the heart of the code base: it builds the network model and also contains the functions used for audio generation. Since there is a lot of material, and the generation functions largely mirror the model-building ones, this post only covers the model-building functions in detail.

model.py contains two functions and one class. The two functions, create_variable() and create_bias_variable(), simply create and initialize weight and bias variables.

The key methods of the WaveNetModel class are: _create_variables() (variable creation), _create_causal_layer() (causal convolution), _create_dilation_layer() (dilated convolution), _create_network() (network assembly), and loss(). I ran some of these functions on toy inputs to see how they behave; the phrase "in the test" in the code comments below refers to the values used in those standalone tests.

2.4.1 The variable-creation function _create_variables()

This function creates every variable the model needs (for the causal layer, the dilated stack, and the postprocessing layers) and stores them in a dictionary for later use.
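
A condensed, illustrative sketch of the idea (not the full function: the causal-layer, postprocessing, and bias variables are omitted). The weight shapes follow the conv1d convention [filter_width, in_channels, out_channels]:

    def _create_variables(self):
        # Condensed sketch: one dict of weights per dilated layer.
        var = {'dilated_stack': []}
        for i, dilation in enumerate(self.dilations):
            layer = dict()
            layer['filter'] = create_variable(
                'filter', [self.filter_width,
                           self.residual_channels, self.dilation_channels])
            layer['gate'] = create_variable(
                'gate', [self.filter_width,
                         self.residual_channels, self.dilation_channels])
            layer['dense'] = create_variable(
                'dense', [1, self.dilation_channels, self.residual_channels])
            layer['skip'] = create_variable(
                'skip', [1, self.dilation_channels, self.skip_channels])
            var['dilated_stack'].append(layer)
        # The real function also creates the initial causal-layer filter and
        # the two postprocessing weight matrices (plus biases when enabled).
        return var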

2.4.2 The causal convolution function _create_causal_layer()

This function builds the causal convolution layer by calling causal_conv() from ops.py, so let's see what causal_conv() actually does. As the test traces below show, causality is enforced by shifting the signal: tf.pad() prepends zeros to the input before the convolution, so each output depends only on current and past samples.

def time_to_batch(value, dilation, name=None):
    with tf.name_scope('time_to_batch'):
        # In the test, the incoming value has shape (1, 9, 1)
        # and dilation = 4
        shape = tf.shape(value)
        # pad_elements evaluates to 3
        pad_elements = dilation - 1 - (shape[1] + dilation - 1) % dilation
        # After padding, the shape is (1, 12, 1): 3 zeros are appended along the time axis
        padded = tf.pad(value, [[0, 0], [0, pad_elements], [0, 0]])
        # After the reshape, the shape is (3, 4, 1)
        reshaped = tf.reshape(padded, [-1, dilation, shape[2]])
        # After the transpose, the shape is (4, 3, 1)
        transposed = tf.transpose(reshaped, perm=[1, 0, 2])
        # The returned tensor has shape (4, 3, 1)
        return tf.reshape(transposed, [shape[0] * dilation, -1, shape[2]])


def batch_to_time(value, dilation, name=None):
    with tf.name_scope('batch_to_time'):
        shape = tf.shape(value)
        prepared = tf.reshape(value, [dilation, -1, shape[2]])
        transposed = tf.transpose(prepared, perm=[1, 0, 2])
        # Returns a tensor with the shape of the original input to time_to_batch
        # (in the test, (1, 9, 1))
        return tf.reshape(transposed,
                          [tf.div(shape[0], dilation), -1, shape[2]])

def causal_conv(value, filter_, dilation, name='causal_conv'):
    with tf.name_scope(name):
        # Pad beforehand to preserve causality.
        # In the test, filter_width = 2
        filter_width = tf.shape(filter_)[0]
        # In the test, dilation = 4,
        # so padding = [[0, 0], [4, 0], [0, 0]]
        padding = [[0, 0], [(filter_width - 1) * dilation, 0], [0, 0]]
        # In the test, value has shape (1, 5, 1); the padding prepends
        # 4 zeros along the time axis, so padded has shape (1, 9, 1)
        padded = tf.pad(value, padding)
        if dilation > 1:
            # See the time_to_batch trace above;
            # in the test this returns shape (4, 3, 1)
            transformed = time_to_batch(padded, dilation)

            conv = tf.nn.conv1d(transformed, filter_, stride=1, padding='SAME')
            # Convert back to the original time-major layout (see batch_to_time above)
            restored = batch_to_time(conv, dilation)
        else:
            restored = tf.nn.conv1d(padded, filter_, stride=1, padding='SAME')
        # Remove excess elements at the end.
        result = tf.slice(restored,
                          [0, 0, 0],
                          [-1, tf.shape(value)[1], -1])
        # The result has the same time length as the original (unpadded) input value
        return result
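
The "in the test" shapes in the comments above come from running these ops on a toy input. A sketch of such a test (assuming the three functions above are in scope):

import numpy as np
import tensorflow as tf

value = tf.constant(np.arange(1, 6, dtype=np.float32).reshape(1, 5, 1))  # shape (1, 5, 1)
filter_ = tf.constant(np.ones((2, 1, 1), dtype=np.float32))              # filter_width = 2

out = causal_conv(value, filter_, dilation=4)
with tf.Session() as sess:
    # Prints [[1. 2. 3. 4. 6.]]: with an all-ones filter each output is
    # x[t] + x[t-4], and x[t-4] is treated as zero before the signal starts.
    print(sess.run(out)[:, :, 0])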

2.4.3 The dilated convolution function _create_dilation_layer()

This function implements a single dilated convolution layer, together with the residual and skip connections that help the model converge faster. The structure of the layer is shown in the docstring below.

    def _create_dilation_layer(self, input_batch, layer_index, dilation):
        '''Creates a single causal dilated convolution layer.

        The layer contains a gated filter that connects to dense output
        and to a skip connection:

               |-> [gate]   -|        |-> 1x1 conv -> skip output
               |             |-> (*) -|
        input -|-> [filter] -|        |-> 1x1 conv -|
               |                                    |-> (+) -> dense output
               |------------------------------------|

        Where `[gate]` and `[filter]` are causal convolutions with a
        non-linear activation at the output.
        '''
        variables = self.variables['dilated_stack'][layer_index]

        weights_filter = variables['filter']
        weights_gate = variables['gate']

        # Filter convolution
        conv_filter = causal_conv(input_batch, weights_filter, dilation)
        # Gate convolution
        conv_gate = causal_conv(input_batch, weights_gate, dilation)

        # Optionally add biases
        if self.use_biases:
            filter_bias = variables['filter_bias']
            gate_bias = variables['gate_bias']
            conv_filter = tf.add(conv_filter, filter_bias)
            conv_gate = tf.add(conv_gate, gate_bias)

        # Gated activation: combine the filter and gate branches
        out = tf.tanh(conv_filter) * tf.sigmoid(conv_gate)

        # The 1x1 conv to produce the residual output
        weights_dense = variables['dense']
        transformed = tf.nn.conv1d(
            out, weights_dense, stride=1, padding="SAME", name="dense")

        # The 1x1 conv to produce the skip output
        weights_skip = variables['skip']
        #skip output
        skip_contribution = tf.nn.conv1d(
            out, weights_skip, stride=1, padding="SAME", name="skip")

        if self.use_biases:
            dense_bias = variables['dense_bias']
            skip_bias = variables['skip_bias']
            transformed = transformed + dense_bias
            skip_contribution = skip_contribution + skip_bias

            
        layer = 'layer{}'.format(layer_index)
        # Add histogram summaries
        tf.histogram_summary(layer + '_filter', weights_filter)
        tf.histogram_summary(layer + '_gate', weights_gate)
        tf.histogram_summary(layer + '_dense', weights_dense)
        tf.histogram_summary(layer + '_skip', weights_skip)
        if self.use_biases:
            tf.histogram_summary(layer + '_biases_filter', filter_bias)
            tf.histogram_summary(layer + '_biases_gate', gate_bias)
            tf.histogram_summary(layer + '_biases_dense', dense_bias)
            tf.histogram_summary(layer + '_biases_skip', skip_bias)

        # Return the skip output and the dense output (input + residual)
        return skip_contribution, input_batch + transformed
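
In equation form (as in the WaveNet paper), the layer computes the gated activation

    z = tanh(W_filter ∗ x) ⊙ sigmoid(W_gate ∗ x)

where ∗ denotes the dilated causal convolution and ⊙ elementwise multiplication. The dense (residual) path returned above is then input + conv1x1_dense(z), and the skip path is conv1x1_skip(z), which is exactly what the last line of the function returns.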

2.4.4 The network-building function _create_network()

This function uses _create_dilation_layer() to assemble the full network. After the dilated convolution stack it appends a postprocessing block with the structure: (+) -> ReLU -> 1x1 conv -> ReLU -> 1x1 conv.


    # Build the model
    def _create_network(self, input_batch):
        '''Construct the WaveNet network.'''
        outputs = []
        current_layer = input_batch

        # Pre-process the input with a regular convolution
        if self.scalar_input:
            initial_channels = 1
        else:
            initial_channels = self.quantization_channels

        # Initial causal layer
        current_layer = self._create_causal_layer(current_layer)

        # Add all defined dilation layers.
        # Build the dilated layers, one per entry in the dilations list (18 layers in this configuration)
        with tf.name_scope('dilated_stack'):
            for layer_index, dilation in enumerate(self.dilations):
                with tf.name_scope('layer{}'.format(layer_index)):
                    output, current_layer = self._create_dilation_layer(
                        current_layer, layer_index, dilation)
                    outputs.append(output)

        # Postprocessing layers
        with tf.name_scope('postprocessing'):
            # Perform (+) -> ReLU -> 1x1 conv -> ReLU -> 1x1 conv to
            # postprocess the output.
            # The postprocessing-layer weights (created earlier in _create_variables())
            w1 = self.variables['postprocessing']['postprocess1']
            w2 = self.variables['postprocessing']['postprocess2']
            if self.use_biases:
                b1 = self.variables['postprocessing']['postprocess1_bias']
                b2 = self.variables['postprocessing']['postprocess2_bias']

            tf.histogram_summary('postprocess1_weights', w1)
            tf.histogram_summary('postprocess2_weights', w2)
            if self.use_biases:
                tf.histogram_summary('postprocess1_biases', b1)
                tf.histogram_summary('postprocess2_biases', b2)

            # We skip connections from the outputs of each layer, adding them
            # all up here.
            # Sum up the skip-connection outputs from all layers
            total = sum(outputs)
            transformed1 = tf.nn.relu(total)
            conv1 = tf.nn.conv1d(transformed1, w1, stride=1, padding="SAME")
            if self.use_biases:
                conv1 = tf.add(conv1, b1)
            transformed2 = tf.nn.relu(conv1)
            conv2 = tf.nn.conv1d(transformed2, w2, stride=1, padding="SAME")
            if self.use_biases:
                conv2 = tf.add(conv2, b2)

        return conv2

2.4.5 The loss() function

This function first applies μ-law encoding (mu_law_encode()) to the input audio and then one-hot encodes the result. The loss itself is computed with tf.nn.softmax_cross_entropy_with_logits().
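
For reference, μ-law companding maps a sample x in [-1, 1] to sign(x) · ln(1 + μ|x|) / ln(1 + μ), with μ = quantization_channels - 1 (255 by default), and the result is then quantized into integer bins. A NumPy sketch of the encoding (the repository's mu_law_encode() implements the same transform with TensorFlow ops):

import numpy as np

def mu_law_encode_np(audio, quantization_channels=256):
    # Mu-law companding followed by quantization into integer bins (NumPy sketch).
    mu = quantization_channels - 1
    magnitude = np.log1p(mu * np.abs(audio)) / np.log1p(mu)
    signal = np.sign(audio) * magnitude
    # Map [-1, 1] to the integer range [0, mu].
    return ((signal + 1) / 2 * mu + 0.5).astype(np.int32)

print(mu_law_encode_np(np.array([-1.0, -0.5, 0.0, 0.5, 1.0])))  # [  0  16 128 239 255]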

    # Loss function
    def loss(self,
             input_batch,
             l2_regularization_strength=None,
             name='wavenet'):
        '''Creates a WaveNet network and returns the autoencoding loss.

        The variables are all scoped to the given name.
        '''
        with tf.name_scope(name):

            # Mu-law encode the input
            input_batch = mu_law_encode(input_batch,
                                        self.quantization_channels)

            # Then one-hot encode it
            encoded = self._one_hot(input_batch)
            # With scalar input, feed the mu-law values directly instead of the one-hot encoding
            if self.scalar_input:
                network_input = tf.reshape(
                    tf.cast(input_batch, tf.float32),
                    [self.batch_size, -1, 1])
            else:
                network_input = encoded

            # Network output (the predictions)
            raw_output = self._create_network(network_input)

            with tf.name_scope('loss'):
                # Shift original input left by one sample, which means that
                # each output sample has to predict the next input sample.
                # In the test, encoded has shape (1, 9, 1), e.g. [0,0,0,0,1..5];
                # after the slice, shifted has shape (1, 8, 1), e.g. [0,0,0,1..5]
                shifted = tf.slice(encoded, [0, 1, 0],
                                   [-1, tf.shape(encoded)[1] - 1, -1])
                # After padding with a zero at the end, the shape is back to (1, 9, 1), e.g. [0,0,0,1..5,0]
                shifted = tf.pad(shifted, [[0, 0], [0, 1], [0, 0]])

                # Reshape the network output to (num_samples, quantization_channels)
                prediction = tf.reshape(raw_output,
                                        [-1, self.quantization_channels])
                # Cross-entropy loss
                loss = tf.nn.softmax_cross_entropy_with_logits(
                    prediction,
                    tf.reshape(shifted, [-1, self.quantization_channels]))
                reduced_loss = tf.reduce_mean(loss)

                tf.scalar_summary('loss', reduced_loss)

                if l2_regularization_strength is None:
                    return reduced_loss
                else:
                    # L2 regularization for all trainable parameters
                    l2_loss = tf.add_n([tf.nn.l2_loss(v)
                                        for v in tf.trainable_variables()
                                        if not('bias' in v.name)])

                    # Add the regularization term to the loss
                    total_loss = (reduced_loss +
                                  l2_regularization_strength * l2_loss)

                    tf.scalar_summary('l2_loss', l2_loss)
                    tf.scalar_summary('total_loss', total_loss)

                    return total_loss
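
The _one_hot() helper used above is essentially a thin wrapper around tf.one_hot; roughly (a sketch, not the exact code from the repository):

    def _one_hot(self, input_batch):
        # Expand each integer sample (0 .. quantization_channels - 1) into a
        # one-hot vector, giving a tensor of shape (batch, time, channels).
        encoded = tf.one_hot(input_batch,
                             depth=self.quantization_channels,
                             dtype=tf.float32)
        return tf.reshape(encoded,
                          [self.batch_size, -1, self.quantization_channels])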

2.5 generate.py walkthrough

This script generates audio from a trained model. Given the walkthrough above, the code is relatively straightforward, so I will skip it here. Note that GitHub also hosts a Fast WaveNet implementation, which addresses the main drawback of the generation procedure described in the original WaveNet paper, namely that sample-by-sample generation is very slow; it is worth a look if you are interested.

3. Summary

WaveNet combines causal and dilated convolutions so that the receptive field grows exponentially with model depth. The architecture is not limited to generating raw audio: it has also been applied to text generation (text-wavenet) and image generation (image-wavenet), and it is well worth further study.

Edited on 2017-01-14