
PyTorch Tutorial 15.10: Pretraining BERT


With the BERT model implemented in Section 15.8 and the pretraining examples generated from the WikiText-2 dataset in Section 15.9, we will pretrain BERT on the WikiText-2 dataset in this section.

import torch
from torch import nn
from d2l import torch as d2l

To start, we load the WikiText-2 dataset as minibatches of pretraining examples for masked language modeling and next sentence prediction. The batch size is 512 and the maximum length of a BERT input sequence is 64. Note that in the original BERT model, the maximum length is 512.

batch_size, max_len = 512, 64
train_iter, vocab = d2l.load_data_wiki(batch_size, max_len)
Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip...
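
To get a sense of what these minibatches contain, the following quick inspection (a sketch, not part of the original text) prints the shape of every field of the first minibatch; the first dimension of each tensor is the batch size of 512.

# Fields: token IDs, segment IDs, valid lengths, positions to predict,
# prediction weights, masked-language-model labels, next-sentence labels
(tokens_X, segments_X, valid_lens_x, pred_positions_X,
 mlm_weights_X, mlm_Y, nsp_y) = next(iter(train_iter))
print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
      pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape, nsp_y.shape)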

15.10.1. Pretraining BERT

The original BERT has two versions of different model sizes (Devlin et al., 2018). The base model (BERT_BASE) uses 12 layers (Transformer encoder blocks) with 768 hidden units (hidden size) and 12 self-attention heads. The large model (BERT_LARGE) uses 24 layers with 1024 hidden units and 16 self-attention heads. Notably, the former has 110 million parameters while the latter has 340 million parameters. For ease of demonstration, we define a small BERT, using 2 layers, 128 hidden units, and 2 self-attention heads.

net = d2l.BERTModel(len(vocab), num_hiddens=128, ffn_num_hiddens=256,
                    num_heads=2, num_blks=2, dropout=0.2)
devices = d2l.try_all_gpus()
loss = nn.CrossEntropyLoss()
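
As a rough sanity check (a sketch that is not part of the original text), we can count the learnable parameters of this small BERT. Since the d2l PyTorch implementation of BERTModel may use lazily initialized layers, we run one forward pass on a minibatch first so that all parameters are materialized.

with torch.no_grad():
    net(*next(iter(train_iter))[:4])  # materialize any lazily initialized layers
num_params = sum(p.numel() for p in net.parameters())
print(f'number of parameters in this small BERT: {num_params}')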

Before defining the training loop, we define a helper function _get_batch_loss_bert. Given a minibatch of training examples, this function computes the loss for both the masked language modeling and next sentence prediction tasks. Note that the final loss of BERT pretraining is just the sum of the masked language modeling loss and the next sentence prediction loss.

#@save
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,
                         segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # Compute masked language model loss
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *\
    mlm_weights_X.reshape(-1, 1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # Compute next sentence prediction loss
    nsp_l = loss(nsp_Y_hat, nsp_y)
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l
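
As a minimal usage sketch (assuming the small model defined above, still residing on the CPU), the helper can be evaluated once on a single minibatch; the returned total loss is simply the sum of the two task losses.

# Evaluate the combined pretraining loss on one minibatch (no gradient update)
(tokens_X, segments_X, valid_lens_x, pred_positions_X,
 mlm_weights_X, mlm_Y, nsp_y) = next(iter(train_iter))
with torch.no_grad():
    mlm_l, nsp_l, l = _get_batch_loss_bert(
        net, loss, len(vocab), tokens_X, segments_X, valid_lens_x,
        pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
print(f'MLM loss {mlm_l:.3f}, NSP loss {nsp_l:.3f}, total loss {l:.3f}')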

Invoking the helper function defined above, the following train_bert function defines the procedure to pretrain BERT (net) on the WikiText-2 (train_iter) dataset. Training BERT can take very long. Instead of specifying the number of epochs for training as in the train_ch13 function (see Section 14.1), the input num_steps of the following function specifies the number of iteration steps for training.

def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    # Run a dummy forward pass so that (lazily initialized) parameters exist
    # before the model is wrapped in DataParallel and the optimizer is built
    net(*next(iter(train_iter))[:4])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence prediction
    # losses, no. of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
            mlm_weights_X, mlm_Y, nsp_y in train_iter:
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
            trainer.zero_grad()
            timer.start()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()
            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break

    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')
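
We can now pretrain BERT by invoking train_bert. The 50 iteration steps below are illustrative; more steps lower the loss at the cost of longer training.

train_bert(train_iter, net, loss, len(vocab), devices, 50)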
