×

PyTorch教程11.5之多头注意力

消耗积分:0 | 格式:pdf | 大小:0.14 MB | 2023-06-05

李超

分享资料个

在实践中,给定一组相同的查询、键和值,我们可能希望我们的模型结合来自同一注意机制的不同行为的知识,例如捕获各种范围的依赖关系(例如,较短范围与较长范围)在一个序列中。因此,这可能是有益的
允许我们的注意力机制联合使用查询、键和值的不同表示子空间。

为此,可以使用以下方式转换查询、键和值,而不是执行单个注意力池h独立学习线性投影。那么这些h投影查询、键和值被并行输入注意力池。到底,h 注意池的输出与另一个学习的线性投影连接并转换以产生最终输出。这种设计称为多头注意力,其中每个hattention pooling outputs 是一个 Vaswani et al. , 2017使用全连接层执行可学习的线性变换,图 11.5.1描述了多头注意力。

../_images/多头注意力.svg

图 11.5.1多头注意力,其中多个头连接起来然后进行线性变换。

import math
import torch
from torch import nn
from d2l import torch as d2l
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l

11.5.1。模型

在提供多头注意力的实现之前,让我们从数学上形式化这个模型。给定一个查询 q∈Rdq, 关键 k∈Rdk和一个值 v∈Rdv, 每个注意力头 hi(i=1,…,h) 被计算为

(11.5.1)hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rpv,

其中可学习参数 Wi(q)∈Rpq×dq, Wi(k)∈Rpk×dkWi(v)∈Rpv×dv, 和f是注意力集中,例如11.3 节中的附加注意力和缩放点积注意力。多头注意力输出是另一种通过可学习参数进行的线性变换Wo∈Rpo×hpv的串联h负责人:

(11.5.2)Wo[h1⋮hh]∈Rpo.

基于这种设计,每个头可能会关注输入的不同部分。可以表达比简单加权平均更复杂的函数。

11.5.2。执行

在我们的实现中,我们为多头注意力的每个头选择缩放的点积注意力。为了避免计算成本和参数化成本的显着增长,我们设置 pq=pk=pv=po/h


声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !