在一般的 seq2seq 问题中,如机器翻译(第 10.5 节),输入和输出的长度不同且未对齐。处理这类数据的标准方法是设计一个编码器-解码器架构(图 10.6.1),它由两个主要组件组成:一个 编码器,它以可变长度序列作为输入,以及一个 解码器,作为一个条件语言模型,接收编码输入和目标序列的向左上下文,并预测目标序列中的后续标记。
让我们以从英语到法语的机器翻译为例。给定一个英文输入序列:“They”、“are”、“watching”、“.”,这种编码器-解码器架构首先将可变长度输入编码为一个状态,然后对该状态进行解码以生成翻译后的序列,token通过标记,作为输出:“Ils”、“regardent”、“.”。由于编码器-解码器架构构成了后续章节中不同 seq2seq 模型的基础,因此本节将此架构转换为稍后将实现的接口。
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
10.6.1。编码器
在编码器接口中,我们只是指定编码器将可变长度序列作为输入X
。实现将由继承此基类的任何模型提供Encoder
。
10.6.2。解码器
在下面的解码器接口中,我们添加了一个额外的init_state
方法来将编码器输出 ( enc_all_outputs
) 转换为编码状态。请注意,此步骤可能需要额外的输入,例如输入的有效长度,这在 第 10.5 节中有解释。为了逐个令牌生成可变长度序列令牌,每次解码器都可以将输入(例如,在先前时间步生成的令牌)和编码状态映射到当前时间步的输出令牌。
class Decoder(nn.Module): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
class Decoder(nn.Block): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
class Decoder(nn.Module): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def setup(self):
raise NotImplementedError
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def __call__(self, X, state):
raise NotImplementedError
class Decoder(tf.keras.layers.Layer): #@save
"""The base decoder interface for the encoder-decoder architecture."""
def __init__(self):
super().__init__()
# Later there can be additional arguments (e.g., length excluding padding)
def init_state(self, enc_all_outputs, *args):
raise NotImplementedError
def call(self, X, state):
raise NotImplementedError
10.6.3。将编码器和解码器放在一起
在前向传播中,编码器的输出用于产生编码状态,解码器将进一步使用该状态作为其输入之一。
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
encoder: nn.Module
decoder: nn.Module
training: bool
def __call__(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args, training=self.training)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state, training=self.training)[0]
class EncoderDecoder(d2l.Classifier): #@save
"""The base class for the encoder-decoder architecture."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def call(self, enc_X, dec_X, *args):
enc_all_outputs = self.encoder(enc_X, *args, training=True)
dec_state = self.decoder.init_state(enc_all_outputs, *args)
# Return decoder output only
return self.decoder(dec_X, dec_state, training=True)[0]
在下一节中,我们将看到如何应用 RNN 来设计基于这种编码器-解码器架构的 seq2seq 模型。
10.6.4。概括
编码器-解码器架构可以处理由可变长度序列组成的输入和输出,因此适用于机器翻译等 seq2seq 问题。编码器将可变长度序列作为输入,并将其转换为具有固定形状的状态。解码器将固定形状的编码状态映射到可变长度序列。
10.6.5。练习
-
假设我们使用神经网络来实现编码器-解码器架构。编码器和解码器必须是同一类型的神经网络吗?
-
除了机器翻译,你能想到另一个可以应用编码器-解码器架构的应用程序吗?