×

PyTorch教程15.4之预训练word2vec

消耗积分:0 | 格式:pdf | 大小:0.14 MB | 2023-06-05

张艳

分享资料个

我们继续实现 15.1 节中定义的 skip-gram 模型。然后我们将在 PTB 数据集上使用负采样来预训练 word2vec。首先,让我们通过调用函数来获取数据迭代器和这个数据集的词汇表 ,这在第 15.3 节d2l.load_data_ptb中有描述

import math
import torch
from torch import nn
from d2l import torch as d2l

batch_size, max_window_size, num_noise_words = 512, 5, 5
data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                   num_noise_words)
Downloading ../data/ptb.zip from http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip...
import math
from mxnet import autograd, gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

batch_size, max_window_size, num_noise_words = 512, 5, 5
data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
                   num_noise_words)

15.4.1。Skip-Gram 模型

我们通过使用嵌入层和批量矩阵乘法来实现 skip-gram 模型。首先,让我们回顾一下嵌入层是如何工作的。

15.4.1.1。嵌入层

如第 10.7 节所述,嵌入层将标记的索引映射到其特征向量。该层的权重是一个矩阵,其行数等于字典大小 ( input_dim),列数等于每个标记的向量维数 ( output_dim)。一个词嵌入模型训练好之后,这个权重就是我们所需要的。

embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
print(f'Parameter embedding_weight ({embed.weight.shape}, '
   f'dtype={embed.weight.dtype})')
Parameter embedding_weight (torch.Size([20, 4]), dtype=torch.float32)
embed = nn.Embedding(input_dim=20, output_dim=4)
embed.initialize()
embed.weight
Parameter embedding0_weight (shape=(20, 4), dtype=float32)

嵌入层的输入是标记(单词)的索引。对于任何令牌索引i,它的向量表示可以从ith嵌入层中权重矩阵的行。由于向量维度 ( output_dim) 设置为 4,因此嵌入层返回形状为 (2, 3, 4) 的向量,用于形状为 (2, 3) 的标记索引的小批量。

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
embed(x)
tensor([[[-0.6501, 1.3547, 0.7968, 0.3916],
     [ 0.4739, -0.0944, 1.2308, 0.6457],
     [ 0.4539, 1.5194, 0.4377, -1.5122]],

    [[-0.7032, -0.1213, 0.2657, -0.6797],
     [ 0.2930, -0.6564, 0.8960, -0.5637],
     [-0.1815, 0.9487, 0.8482, 0.5486]]], grad_fn=<EmbeddingBackward0>)
x = np.array([[1, 2, 3], [4, 5, 6]])
embed(x)
array([[[ 0.01438687, 0.05011239, 0.00628365, 0.04861524],
    [-0.01068833, 0.01729892, 0.02042518, -0.01618656],
    [-0.00873779, -0.02834515, 0.05484822, -0.06206018]],

    [[ 0.06491279, -0.03182812, -0.01631819, -0.00312688],
    [ 0.0408415 , 0.04370362, 0.00404529, -0.0028032 ],
    [ 0.00952624, -0.01501013, 0.05958354, 0.04705103]]])

15.4.1.2。定义前向传播

在正向传播中,skip-gram 模型的输入包括形状为(批大小,1)的中心词索引和 形状为(批大小,center的连接上下文和噪声词索引,其中定义在 第 15.3.5 节. 这两个变量首先通过嵌入层从标记索引转换为向量,然后它们的批量矩阵乘法(在第 11.3.2.2 节中描述)返回形状为(批量大小,1, )的输出 。输出中的每个元素都是中心词向量与上下文或噪声词向量的点积。contexts_and_negativesmax_lenmax_lenmax_len

def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
  v = embed_v(center)
  u = embed_u(contexts_and_negatives)
  pred = torch.bmm(v, u.permute(0, 2, 1))
  return pred
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
  v = embed_v(center)
  u = embed_u(contexts_and_negatives)
  pred = npx.batch_dot(v, u.swapaxes(1, 2))
  return pred

skip_gram让我们为一些示例输入打印此函数的输出形状。

skip_gram(torch.ones((2, 1), dtype=torch.long),
     torch.ones((2, 4), dtype=torch.long), embed, embed).shape
torch.Size([2, 1, 4])
skip_gram(np.ones((2, 1)), np.ones((2, 4)), embed, embed).shape
(2, 1, 4)

15.4.2。训练

在用负采样训练skip-gram模型之前,我们先定义它的损失函数。

15.4.2.1。二元交叉熵损失

根据15.2.1节负采样损失函数的定义,我们将使用二元交叉熵损失。

class SigmoidBCELoss(nn.Module):
  # Binary cross-entropy loss with masking
  def __init__(self):
    super().__init__()

  def forward(self, inputs, target, mask=None):
    out = nn.functional.binary_cross_entropy_with_logits(
      inputs, target, weight=mask, reduction="none")
    return out.mean(dim=1)

loss = SigmoidBCELoss()
loss = gluon.loss.SigmoidBCELoss()

回想我们在第 15.3.5 节中对掩码变量和标签变量的描述 下面计算给定变量的二元交叉熵损失。

pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)
tensor([0.9352, 1.8462])
pred = np.array([[1.1, -2.2, 3.3, -4.4]] * 2)
label = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = np.array([[1,

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !