基于Transformer的语音合成系统

基于Transformer的语音合成系统

Google于2017年提出的Transformer结构颠覆了整个自然语言处理领域,至此,处理序列的模型不仅仅限于CNN和RNN,又多了Transformer,并且Transformer在许多方面相较于CNN和RNN都有很大的提高。

Transformer最先用于神经机器翻译领域,之后又在许多领域证明了其强大的性能。本文将根据论文Neural Speech Synthesis with Transformer Network简单地阐述将Transformer用于语音合成(TTS)的方法,并给出基于Pytorch的实现:

xcmyz/Transformer-TTSgithub.com图标

注:GitHub上已经有很好的实现了,且本实现基于神经机器翻译的Transformer实现

Neural Speech Synthesis with Transformer Network这篇论文虽然有8页纸,但是大部分章节都用来复述Original Transformer和Tacotron2的方法,作者把Transformer和Tacotron2融合,形成了Transformer-TTS。下面是模型的图示:

由图可知,模型的主体还是Original Transformer,只是在输入阶段和输出阶段为了配合语音数据的特性做了改变。首先是Encoder的Input阶段,先将text逐字符转化为编号,方便Embedding,然后进入Encoder PreNet,这层网络由一个Embedding layer和三层卷积层构成,转化为512维的向量后,进入Transformer Encoder。其次是Transformer的Decoder部分,分为Input和Output。Input通过一个PreNet,将80维的梅尔声谱图转化为512维向量,这里的PreNet是一个三层的全连接网络(个人认为论文中应当解释一下为什么Encoder的PreNet是用卷积设计的,而Decoder的PreNet由全连接网络就可以解决问题);Output部分与Tacotron2的设计完全一致。

Tacotron2结构

论文作者对这一模型做了很多的实验,总的来说,训练时期的速度大大提高,加快了2到3倍,生成语音的质量也好于传统RNN结构模型(存疑,复现版本仅仅能做到效果相接近,可能是作者的调参技艺比较高超)。

基于Transformer的TTS模型已是现在主流的End-to-End TTS系统的baseline,它的实现必不可少,而且因为Transformer本身优异的结构,也能大大加快实验的速度。

放一下部分代码:

class TransformerTTS(nn.Module):
    """ TTS model based on Transformer """

    def __init__(self, num_mel=80, embedding_size=512):
        super(TransformerTTS, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.postnet = PostNet()
        self.stop_linear = Linear(embedding_size, 1, w_init='sigmoid')
        self.mel_linear = Linear(embedding_size, num_mel)

    def forward(self, src_seq, src_pos, tgt_seq, tgt_pos, mel_tgt, return_attns=False):
        encoder_output = self.encoder(src_seq, src_pos)
        decoder_output = self.decoder(
            tgt_seq, tgt_pos, src_seq, encoder_output[0], mel_tgt)
        decoder_output = decoder_output[0]

        mel_output = self.mel_linear(decoder_output)
        mel_output_postnet = self.postnet(mel_output) + mel_output

        stop_token = self.stop_linear(decoder_output)
        stop_token = stop_token.squeeze(2)

        return mel_output, mel_output_postnet, stop_token

发布于 2019-05-25

文章被以下专栏收录