×

PyTorch教程15.6之子词嵌入

消耗积分:0 | 格式:pdf | 大小:0.13 MB | 2023-06-05

分享资料个

在英语中,诸如“helps”、“helped”和“helping”之类的词是同一个词“help”的变形形式。“dog”与“dogs”的关系与“cat”与“cats”的关系相同,“boy”与“boyfriend”的关系与“girl”与“girlfriend”的关系相同。在法语和西班牙语等其他语言中,许多动词有超过 40 种变形形式,而在芬兰语中,一个名词可能有多达 15 种格。在语言学中,形态学研究词的形成和词的关系。然而,word2vec 和 GloVe 都没有探索单词的内部结构。

15.6.1。fastText 模型

回想一下单词在 word2vec 中是如何表示的。在skip-gram模型和连续词袋模型中,同一个词的不同变形形式直接由不同的向量表示,没有共享参数。为了使用形态信息,fastText 模型提出了一种子词嵌入方法,其中一个子词是一个字符n-gram Bojanowski等人,2017 年与学习词级向量表示不同,fastText 可以被视为子词级 skip-gram,其中每个中心词由其子词向量的总和表示。

让我们举例说明如何使用单词“where”为 fastText 中的每个中心词获取子词。首先,在单词的首尾添加特殊字符“<”和“>”,以区别于其他子词的前缀和后缀。然后,提取字符n-克从字。例如,当n=3,我们得到所有长度为 3 的子词:“”,以及特殊子词“”。”、“whe”、“her”、“ere”、“re>

在 fastText 中,对于任何单词w, 表示为Gw其所有长度在 3 到 6 之间的子字及其特殊子字的并集。词汇表是所有词的子词的并集。出租zg是子词的向量g在字典中,向量vw为词w作为 skip-gram 模型中的中心词的是其子词向量的总和:

(15.6.1)vw=∑g∈Gwzg.

fastText 的其余部分与 skip-gram 模型相同。与skip-gram模型相比,fastText中的词汇量更大,导致模型参数更多。此外,为了计算一个词的表示,必须将其所有子词向量相加,从而导致更高的计算复杂度。然而,由于具有相似结构的词之间的子词共享参数,稀有词甚至词汇表外的词可能会在 fastText 中获得更好的向量表示。

15.6.2。字节对编码

在 fastText 中,所有提取的子词都必须具有指定的长度,例如36,因此无法预定义词汇量大小。为了允许在固定大小的词汇表中使用可变长度的子词,我们可以应用一种称为字节对编码(BPE) 的压缩算法来提取子词 ( Sennrich et al. , 2015 )

字节对编码对训练数据集进行统计分析,以发现单词中的常见符号,例如任意长度的连续字符。从长度为 1 的符号开始,字节对编码迭代地合并最频繁的一对连续符号以产生新的更长的符号。请注意,为了提高效率,不考虑跨越单词边界的对。最后,我们可以使用子词这样的符号来分词。字节对编码及其变体已用于流行的自然语言处理预训练模型中的输入表示,例如 GPT-2 Radford等人,2019 年和 RoBERTa Liu等人,2019 年. 下面,我们将说明字节对编码的工作原理。

首先,我们将符号词汇表初始化为所有英文小写字符、一个特殊的词尾符号'_'和一个特殊的未知符号'[UNK]'

import collections

symbols = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
      'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
      '_', '[UNK]']
import collections

symbols = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
      'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
      '_', '[UNK]']

由于我们不考虑跨越单词边界的符号对,我们只需要一个字典raw_token_freqs将单词映射到它们在数据集中的频率(出现次数)。请注意,特殊符号'_'附加到每个单词,以便我们可以轻松地从输出符号序列(例如,“a_ tall er_ man”)中恢复单词序列(例如,“a taller man”)。由于我们从仅包含单个字符和特殊符号的词汇表开始合并过程,因此在每个单词中的每对连续字符之间插入空格(字典的键token_freqs)。换句话说,空格是单词中符号之间的分隔符。

raw_token_freqs = {'fast_': 4, 'faster_': 3, 'tall_': 5, 'taller_': 4}
token_freqs = {}
for token, freq in raw_token_freqs.items():
  token_freqs[' '.join(list(token))] = raw_token_freqs[token]
token_freqs
{'f a s t _': 4, 'f a s t e r _': 3, 't a l l _': 5, 't a l l e r _': 4}
raw_token_freqs = {'fast_': 4, 'faster_': 3, 'tall_': 5, 'taller_': 4}
token_freqs = {}
for token, freq in raw_token_freqs.items():
  token_freqs[' '.join(list(token))] = raw_token_freqs[token]
token_freqs
{'f a s t _': 4, 'f a s t e r _': 3, 't a l l _': 5, 't a l l e r _': 4}

我们定义了以下get_max_freq_pair函数,该函数返回单词中出现频率最高的一对连续符号,其中单词来自输入字典的键token_freqs

def get_max_freq_pair(token_freqs):
  pairs = collections.defaultdict(int)
  for token, freq in token_freqs.items():
    symbols = token.split()
    for i in range(len(symbols) - 1):
      # Key of `pairs` is a tuple of two consecutive symbols
      pairs[symbols[i], symbols[i + 1]] += freq
  return max(pairs, key=pairs.get) # Key of `pairs` with the max value
def get_max_freq_pair(token_freqs):
  pairs = collections.defaultdict(int)
  for token, freq in token_freqs.items():
    symbols = token.split()
    for i in range(len(symbols) - 1):
      # Key of `pairs` is a tuple of two consecutive symbols
      pairs[symbols[i], symbols[i + 1]] += freq
  return max(pairs, key=pairs.get) # Key of `pairs` with the max value

作为一种基于连续符号频率的贪婪方法,字节对编码将使用以下merge_symbols函数合并最频繁的一对连续符号以产生新的符号。

def merge_symbols(max_freq_pair, token_freqs, symbols):
  symbols.append(''.join(max_freq_pair))
  new_token_freqs = dict()
  for token, freq in token_freqs.items():
    new_token = token.replace(' '.join(max_freq_pair),
                 ''.join(max_freq_pair))
    new_token_freqs[new_token] = token_freqs[token]
  return new_token_freqs
def merge_symbols(max_freq_pair, token_freqs, symbols):
  symbols.append(''.join(max_freq_pair))
  new_token_freqs = dict()
  for token, freq in token_freqs.items():
    new_token = token.replace(' '.join(max_freq_pair),
                 ''.join(max_freq_pair))
    new_token_freqs[new_token] = token_freqs[token]
  return new_token_freqs

现在我们在字典的键上迭代执行字节对编码算法token_freqs在第一次迭代中,出现频率最高的一对连续符号是't''a',因此字节对编码将它们合并以产生新的符号'ta'<

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !