Reading BERT Text Classification Code Together (PyTorch, Part 2)

BERT is the model Google released last year, setting new records on 11 NLP tasks; I won't go over the model basics in this post. This time I'd like to read through the text classification task run_classifier from the examples directory of huggingface's pytorch-pretrained-BERT with you.

The source code can be found in huggingface's GitHub repository.

huggingface/pytorch-pretrained-BERT (github.com)


In the previous post I covered the data preprocessing part:

周剑: Reading BERT Text Classification Code Together (PyTorch, Part 1) (zhuanlan.zhihu.com)


Continuing from that post, in this one we will read through the model part together.

Let's continue with the main function:

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    model = BertForSequenceClassification.from_pretrained(args.bert_model,
              cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
              num_labels = num_labels)

This code first loads the training data into train_examples, then computes the total number of optimization steps from the number of training examples, the batch size, the gradient accumulation steps, and the number of epochs. The model is then instantiated with BertForSequenceClassification.from_pretrained.
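
As a quick sanity check, here is a minimal sketch of the step calculation with made-up numbers (the dataset size and hyperparameters below are hypothetical, not taken from the repo):

# Hypothetical values, for illustration only
num_examples = 10000              # len(train_examples)
train_batch_size = 32
gradient_accumulation_steps = 1
num_train_epochs = 3.0

num_train_steps = int(
    num_examples / train_batch_size / gradient_accumulation_steps * num_train_epochs)
print(num_train_steps)  # 937: one optimizer step per effective batch, over all epochs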

Open pytorch_pretrained_bert/modeling.py and find the BertForSequenceClassification class. I first mapped out the call relationships inside BertForSequenceClassification. In this post we will read the code of BertForSequenceClassification, its parent class PreTrainedBertModel, and the BertModel that BertForSequenceClassification calls.


Here is the code of the BertForSequenceClassification class:

class BertForSequenceClassification(PreTrainedBertModel):
    """
参数:
    config:指定的bert模型的预训练参数
    num_labels:分类的类别数量
输入:
    input_ids:训练集,torch.LongTensor类型,shape是[batch_size, sequence_length]
    token_type_ids:可选项,当训练集是两句话时才有的。
    attention_mask:可选项,当使用mask才有,可参考原论文。
    labels:数据标签,torch.LongTensor类型,shape是[batch_size]
输出:
    如果labels不是None(训练时):输出的是分类的交叉熵
    如果labels是None(评价时):输出的是shape为[batch_size, num_labels]估计值
    """
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    num_labels = 2
    model = BertForSequenceClassification(config, num_labels)
    logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config, num_labels=2):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits

Looking at the forward function: the input first goes through BertModel, then a dropout layer, and finally a Linear layer that does the classification. In other words, the classification task simply adds a single linear layer on top of the BERT model; when labels are provided the model returns the cross-entropy loss, otherwise it returns the logits.
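
To make the two return paths concrete, here is a minimal usage sketch, assuming model is the BertForSequenceClassification instance built earlier; the tensors are dummy values (in run_classifier they come from the DataLoader):

# Dummy batch, for illustration only
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
segment_ids = torch.LongTensor([[0, 0, 0], [0, 0, 0]])
label_ids = torch.LongTensor([1, 0])

# Training: passing labels makes forward() return the cross-entropy loss
loss = model(input_ids, segment_ids, input_mask, label_ids)
loss.backward()

# Evaluation: without labels, forward() returns the logits of shape [batch_size, num_labels]
with torch.no_grad():
    logits = model(input_ids, segment_ids, input_mask)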


As we can see, BertForSequenceClassification above is a subclass of PreTrainedBertModel, so let's look at the PreTrainedBertModel code next:

class PreTrainedBertModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedBertModel, self).__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_bert_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
        """
        Params:
            pretrained_model_name: the name of the pretrained model to load, one of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
        """
        if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
        else:
            archive_file = pretrained_model_name
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find any file "
                "associated to this path or url.".format(
                    pretrained_model_name,
                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
                    archive_file))
            return None
        if resolved_archive_file == archive_file:
            logger.info("loading archive file {}".format(archive_file))
        else:
            logger.info("loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            logger.info("extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, CONFIG_NAME)
        config = BertConfig.from_json_file(config_file)
        logger.info("Model config {}".format(config))
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
            state_dict = torch.load(weights_path)

        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')
        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
        if len(missing_keys) > 0:
            logger.info("Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys))
        if len(unexpected_keys) > 0:
            logger.info("Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys))
        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)
        return model

PRETRAINED_MODEL_ARCHIVE_MAP here is the same dictionary we saw earlier: it maps each model name to the download URL of its pretrained weights. In short, from_pretrained downloads the archive into the cache (or uses a local path directly), extracts it, builds the model from its config, and then loads the pretrained BERT weights into it, renaming the old gamma/beta keys to weight/bias along the way.
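
As a usage sketch (the model name and num_labels below are example values, not tied to run_classifier's arguments), loading pretrained weights looks like this:

# Load the pretrained weights by name (downloaded and cached on first use);
# num_labels is forwarded through **kwargs to BertForSequenceClassification.__init__.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# A local directory containing the config and weight files should work as well, e.g.:
# model = BertForSequenceClassification.from_pretrained('/path/to/my_bert_dir', num_labels=2)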


Going back to BertForSequenceClassification, its __init__ contains the line self.bert = BertModel(config). Let's now find the BertModel class:

class BertModel(PreTrainedBertModel):
    """
    Params:
        config: the BERT model configuration
    Inputs: the same as for BertForSequenceClassification, except there are no labels and there is an
        extra flag, output_all_encoded_layers, that controls whether every layer's hidden states are returned.
    Outputs: a tuple (encoded_layers, pooled_output)
        `encoded_layers`: controlled by output_all_encoded_layers:
            - `output_all_encoded_layers=True`: a list of the encoded hidden states of every layer,
              each a FloatTensor of shape [batch_size, sequence_length, hidden_size]
            - `output_all_encoded_layers=False`: only the last layer's hidden states,
              a FloatTensor of shape [batch_size, sequence_length, hidden_size]
        `pooled_output`: the sentence-level output used for the sentence-pair training objective,
            a FloatTensor of shape [batch_size, hidden_size]
    """
    def __init__(self, config):
        super(BertModel, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(input_ids, token_type_ids)
        encoded_layers = self.encoder(embedding_output,
                                      extended_attention_mask,
                                      output_all_encoded_layers=output_all_encoded_layers)
        sequence_output = encoded_layers[-1]
        pooled_output = self.pooler(sequence_output)
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
        return encoded_layers, pooled_output

From the forward function we can see that BertModel first turns attention_mask into an extended attention mask, so that padded positions receive a large negative value and are suppressed in the attention softmax. The input then goes through the embeddings layer, which combines the pretrained token, position and segment embeddings, then through the encoder (the stack of Transformer layers), and finally through the pooler, which produces the sentence-level output used for the sentence-pair objective and consumed by the classification head.
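
The extended attention mask is worth a small numeric sketch (the mask values below are made up):

import torch

# Hypothetical mask for a batch of 2 sequences of length 3 (1 = real token, 0 = padding)
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

# Broadcast to [batch_size, 1, 1, seq_len] so it can be added to the attention scores
extended = attention_mask.unsqueeze(1).unsqueeze(2).float()
extended = (1.0 - extended) * -10000.0
print(extended[1, 0, 0])  # tensor([-0., -0., -10000.]): the padded position is pushed towards -inf before the softmax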


The BertModel code calls BertEmbeddings, BertEncoder and BertPooler. We will read those together in the next posts.

周剑: Reading BERT Text Classification Code Together (PyTorch, Part 3) (zhuanlan.zhihu.com)
周剑: Reading BERT Text Classification Code Together (PyTorch, Part 4) (zhuanlan.zhihu.com)
周剑: Reading BERT Text Classification Code Together (PyTorch, Part 5) (zhuanlan.zhihu.com)
周剑: Reading BERT Text Classification Code Together (PyTorch, Part 6) (zhuanlan.zhihu.com)

Edited on 2019-02-09
