在本章前面的部分中,我们为 SNLI 数据集上的自然语言推理任务(如第 16.4 节所述)设计了一个基于注意力的架构(第16.5节)。现在我们通过微调 BERT 重新审视这个任务。正如16.6 节所讨论的 ,自然语言推理是一个序列级文本对分类问题,微调 BERT 只需要一个额外的基于 MLP 的架构,如图 16.7.1所示。
在本节中,我们将下载预训练的小型 BERT 版本,然后对其进行微调以在 SNLI 数据集上进行自然语言推理。
16.7.1。加载预训练的 BERT
我们已经在第 15.9 节和第 15.10 节中解释了如何在 WikiText-2 数据集上预训练 BERT (请注意,原始 BERT 模型是在更大的语料库上预训练的)。如15.10 节所述,原始 BERT 模型有数亿个参数。在下文中,我们提供了两个版本的预训练 BERT:“bert.base”与需要大量计算资源进行微调的原始 BERT 基础模型差不多大,而“bert.small”是一个小版本方便演示。
预训练的 BERT 模型都包含一个定义词汇集的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现以下load_pretrained_model
函数来加载预训练的 BERT 参数。
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
num_heads, num_blks, dropout, max_len, devices):
data_dir = d2l.download_extract(pretrained_model)
# Define an empty vocabulary to load the predefined vocabulary
vocab = d2l.Vocab()
vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
vocab.token_to_idx = {token: idx for idx, token in enumerate(
vocab.idx_to_token)}
bert = d2l.BERTModel(
len(vocab), num_hiddens, ffn_num_hiddens=ffn_num_hiddens, num_heads=4,
num_blks=2, dropout=0.2, max_len=max_len)
# Load pretrained BERT parameters
bert.load_state_dict(torch.load(os.path.join(data_dir,
'pretrained.params')))
return bert, vocab
def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
num_heads, num_blks, dropout, max_len, devices):
data_dir = d2l.download_extract(pretrained_model)
# Define an empty vocabulary to load the predefined vocabulary
vocab = d2l.Vocab()
vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
vocab.token_to_idx = {token: idx for idx, token in enumerate(
vocab.idx_to_token)}
bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads,
num_blks, dropout, max_len)
# Load pretrained BERT parameters
bert.load_parameters(os.path.join(data_dir, 'pretrained.params'),
ctx=devices)
return bert, vocab
为了便于在大多数机器上进行演示,我们将在本节中加载和微调预训练 BERT 的小型版本(“bert.small”)。在练习中,我们将展示如何微调更大的“bert.base”以显着提高测试准确性。
Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...
16.7.2。微调 BERT 的数据集
对于 SNLI 数据集上的下游任务自然语言推理,我们定义了一个自定义的数据集类SNLIBERTDataset
。在每个示例中,前提和假设形成一对文本序列,并被打包到一个 BERT 输入序列中,如图 16.6.2所示。回想第 15.8.4 节 ,段 ID 用于区分 BERT 输入序列中的前提和假设。对于 BERT 输入序列 ( max_len
) 的预定义最大长度,输入文本对中较长者的最后一个标记会不断被删除,直到max_len
满足为止。为了加速生成用于微调 BERT 的 SNLI 数据集,我们使用 4 个工作进程并行生成训练或测试示例。
class SNLIBERTDataset(torch.utils.data.Dataset):
def __init__(self, dataset, max_len, vocab=None):
all_premise_hypothesis_tokens = [[
p_tokens, h_tokens] for p_tokens, h_tokens in zip(
*[d2l.tokenize([s.lower() for s in sentences])
for sentences in dataset[:2]])]
self.labels = torch.tensor(dataset[2])
self.vocab = vocab
self.max_len = max_len
(self.all_token_ids, self.all_segments,
self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
print('read ' + str(len(self.all_token_ids)) + ' examples')
def _preprocess(self, all_premise_hypothesis_tokens):
pool = multiprocessing.Pool(4) # Use 4 worker processes
out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
all_token_ids = [
token_ids for token_ids, segments, valid_len in out]
all_segments = [segments for token_ids, segments, valid_len in out]
valid_lens = [valid_len for token_ids, segments, valid_len in out]
return (torch.tensor(all_token_ids, dtype=torch.long),
torch.tensor(all_segments, dtype=torch.long),
torch.tensor(valid_lens))
def _mp_worker(self, premise_hypothesis_tokens):
p_tokens, h_tokens = premise_hypothesis_tokens
self._truncate_pair_of_tokens(p_tokens, h_tokens)
tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
token_ids = self.vocab[tokens] + [self.vocab['']] \
* (self.max_len - len(tokens))
segments = segments + [0] * (self.max_len - len(segments))
valid_len = len(tokens)
return token_ids, segments, valid_len
def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
# Reserve slots for '', '', and '' tokens for the BERT
# input
while len(p_tokens) + len(h_tokens) > self.max_len - 3:
if len(p_tokens) > len(h_tokens):
p_tokens.pop()
else:
h_tokens.pop()
def __getitem__(self, idx):
return (self.all_token_ids[idx], self.all_segments[idx],
self.valid_lens[idx]), self.labels[idx]
def __len__(self):
return len(self.all_token_ids)
class SNLIBERTDataset(gluon.data.Dataset):
def __init__(self, dataset, max_len, vocab=None):
all_premise_hypothesis_tokens = [[
p_tokens, h_tokens