Encoder-Decoder Architecture

:label:sec_encoder-decoder

As we have discussed in :numref:sec_machine_translation, machine translation is a major problem domain for sequence transduction models, whose input and output are both variable-length sequences. To handle inputs and outputs of this type, we can design an architecture with two major components. The first component is an encoder: it takes a variable-length sequence as the input and transforms it into a state with a fixed shape. The second component is a decoder: it maps the encoded state of a fixed shape to a variable-length sequence. This is called an encoder-decoder architecture, which is depicted in :numref:fig_encoder_decoder.

The encoder-decoder architecture. :label:fig_encoder_decoder
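To make this shape contract concrete, here is a deliberately crude, framework-free sketch in plain Python. The names `encode`, `decode`, and `STATE_DIM` are hypothetical and belong to no library: the "encoder" just counts token ids into a fixed number of buckets, and the "decoder" expands those counts back into a sequence. It learns nothing and discards word order; the only point is that inputs of different lengths yield states of the same shape, while the outputs can again differ in length.

```python
# Toy illustration of the encoder-decoder shape contract (plain Python).
# Assumptions: tokens are small integer ids, and the "state" is a list of
# STATE_DIM floats; a real model would use learned tensors instead.
STATE_DIM = 4  # hypothetical fixed state size


def encode(tokens):
    """Map a variable-length list of token ids to a state of fixed length."""
    state = [0.0] * STATE_DIM
    for t in tokens:
        # Count tokens per bucket: the state length never depends on len(tokens)
        state[t % STATE_DIM] += 1.0
    return state


def decode(state):
    """Map a fixed-length state back to a variable-length list of token ids."""
    out = []
    for bucket, count in enumerate(state):
        out.extend([bucket] * int(count))  # emit each bucket id once per count
    return out


short_state = encode([1, 2])
long_state = encode([1, 2, 3, 5, 7, 0])
print(len(short_state), len(long_state))  # 4 4
print(decode(short_state))  # [1, 2]
print(decode(long_state))   # [0, 1, 1, 2, 3, 3]
```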

Let us take machine translation from English to French as an example. Given an input sequence in English, “They”, “are”, “watching”, “.”, this encoder-decoder architecture first encodes the variable-length input into a state, then decodes the state to generate the translated sequence token by token as the output: “Ils”, “regardent”, “.”. Since the encoder-decoder architecture forms the basis of different sequence transduction models in subsequent sections, this section specifies it as an interface that will be implemented later.
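Token-by-token generation can be sketched as a loop that feeds each generated token back into the decoder until an end-of-sequence token appears. In this plain-Python sketch, the lookup table `NEXT_TOKEN` is a hypothetical stand-in for a trained decoder, `<bos>` and `<eos>` are assumed beginning- and end-of-sequence markers, and the "encoded state" is simply the joined source sentence rather than a fixed-shape tensor.

```python
# Sketch of autoregressive (token-by-token) decoding in plain Python.
# Assumption: NEXT_TOKEN is a hypothetical stand-in for a trained decoder that
# picks the next token from the encoded state and the previous token.
NEXT_TOKEN = {
    ("They are watching .", "<bos>"): "Ils",
    ("They are watching .", "Ils"): "regardent",
    ("They are watching .", "regardent"): ".",
    ("They are watching .", "."): "<eos>",
}


def encode(src_tokens):
    # Hypothetical encoder: here the "state" is just the joined source sentence
    return " ".join(src_tokens)


def decode(state, bos="<bos>", eos="<eos>", max_len=20):
    output, prev = [], bos
    for _ in range(max_len):  # cap the length in case <eos> never appears
        token = NEXT_TOKEN.get((state, prev), eos)  # unknown context: stop
        if token == eos:
            break
        output.append(token)
        prev = token  # the generated token is the next step's input
    return output


print(decode(encode(["They", "are", "watching", "."])))
# ['Ils', 'regardent', '.']
```

The output length (three tokens here) is decided during decoding rather than fixed in advance, which is what allows the decoder to emit a variable-length sequence.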

Encoder

In the encoder interface, we only specify that the encoder takes a variable-length sequence as the input `X`. Any model that inherits this base `Encoder` class provides the actual implementation.

```{.python .input}
from mxnet.gluon import nn

#@save
class Encoder(nn.Block):
    """The base encoder interface for the encoder-decoder architecture."""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        # Subclasses map the variable-length input X to the encoder output
        raise NotImplementedError
```

```{.python .input}
#@tab pytorch
from torch import nn

#@save
class Encoder(nn.Module):
    """The base encoder interface for the encoder-decoder architecture."""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        # Subclasses map the variable-length input X to the encoder output
        raise NotImplementedError
```
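Neither framework is needed to see how the interface is meant to be used, so the following plain-Python sketch mirrors the base `Encoder` (calling the object runs `forward`, imitating `nn.Block` and `nn.Module`) and derives a hypothetical `MeanEncoder` from it. The embedding table is made up for illustration; the encoder averages the embeddings of the input tokens, so a sequence of any length maps to a vector of the embedding size.

```python
class Encoder:
    """Plain-Python mirror of the base Encoder interface (no framework)."""

    def __call__(self, X, *args):
        return self.forward(X, *args)  # imitate nn.Block / nn.Module

    def forward(self, X, *args):
        raise NotImplementedError


class MeanEncoder(Encoder):
    """Hypothetical encoder that averages per-token embedding vectors."""

    def __init__(self, embeddings):
        self.embeddings = embeddings  # token -> list of floats, all the same length

    def forward(self, X, *args):
        vectors = [self.embeddings[token] for token in X]
        # Coordinate-wise mean: the result has the embedding size for any len(X)
        return [sum(column) / len(vectors) for column in zip(*vectors)]


embeddings = {"They": [1.0, 0.0], "are": [0.0, 1.0],
              "watching": [1.0, 1.0], ".": [0.0, 0.0]}
encoder = MeanEncoder(embeddings)
print(encoder(["They", "are", "watching", "."]))  # [0.5, 0.5]
print(encoder(["They", "."]))                      # [0.5, 0.0]
```

Because only the subclass overrides `forward`, calling the base `Encoder` directly still raises `NotImplementedError`, which is exactly what an interface class is for.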

Decoder

In the following decoder interface, we add an `init_state` method that converts the encoder output (`enc_outputs`) into the encoded state. This step may need extra inputs, such as the valid length of the input, which was explained in :numref:subsec_mt_data_loading. To generate a variable-length sequence token by token, at each time step the decoder maps an input (e.g., the token generated at the previous time step) and the encoded state to an output token for the current time step.

```{.python .input}
#@save
class Decoder(nn.Block):
    """The base decoder interface for the encoder-decoder architecture."""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        # Subclasses convert the encoder output (plus extra inputs such as
        # the valid length) into the initial decoder state
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
```

```{.python .input}
#@tab pytorch
#@save
class Decoder(nn.Module):
    """The base decoder interface for the encoder-decoder architecture."""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        # Subclasses convert the encoder output (plus extra inputs such as
        # the valid length) into the initial decoder state
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
```
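The role of the valid length in `init_state` can be illustrated with another framework-free sketch. It mirrors the base `Decoder` in plain Python and defines a hypothetical `MaskedMeanDecoder` whose state is the mean of only the non-padding encoder outputs; its `forward` rule (adding the state to each input) is an arbitrary stand-in for a learned layer.

```python
class Decoder:
    """Plain-Python mirror of the base Decoder interface (no framework)."""

    def __call__(self, X, state):
        return self.forward(X, state)  # imitate nn.Block / nn.Module

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError


class MaskedMeanDecoder(Decoder):
    """Hypothetical decoder whose state is the mean of the valid encoder outputs."""

    def init_state(self, enc_outputs, valid_len):
        # Positions from index valid_len onward are padding, so they are excluded
        return sum(enc_outputs[:valid_len]) / valid_len

    def forward(self, X, state):
        # Toy rule standing in for a learned layer: shift each input by the state
        return [x + state for x in X]


dec = MaskedMeanDecoder()
enc_outputs = [2.0, 4.0, 0.0, 0.0]  # two real positions, then two padding zeros
state = dec.init_state(enc_outputs, 2)
print(state)                   # 3.0 (averaging all four entries would give 1.5)
print(dec([1.0, 2.0], state))  # [4.0, 5.0]
```

Without the valid length, the padding zeros would drag the state down to 1.5, which is why `init_state` accepts extra arguments.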

Putting the Encoder and Decoder Together

Finally, the encoder-decoder architecture contains both an encoder and a decoder, optionally with extra arguments. In forward propagation, the output of the encoder is used to produce the encoded state, and this state is then used by the decoder as one of its inputs. Any extra arguments (such as the valid length) are passed both to the encoder and to the decoder's `init_state` method.

```{.python .input}
#@save
class EncoderDecoder(nn.Block):
    """The base class for the encoder-decoder architecture."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        # Encode the source, turn the encoder output into the decoder state,
        # then decode the decoder input conditioned on that state
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
```

```{.python .input}
#@tab pytorch
#@save
class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        # Encode the source, turn the encoder output into the decoder state,
        # then decode the decoder input conditioned on that state
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)
```
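To see the data flow end to end, the sketch below mirrors `EncoderDecoder` in plain Python and plugs in two hypothetical toy components, `DoublingEncoder` and `SumDecoder`. It shows how an extra argument (here a valid length) reaches both the encoder and `init_state`; the arithmetic inside the toy components is arbitrary and only stands in for learned layers.

```python
class EncoderDecoder:
    """Plain-Python mirror of the EncoderDecoder base class (no framework)."""

    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder

    def __call__(self, enc_X, dec_X, *args):
        return self.forward(enc_X, dec_X, *args)

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)


class DoublingEncoder:
    """Hypothetical encoder: one output per source position."""

    def __call__(self, X, valid_len):
        # valid_len is accepted because EncoderDecoder forwards *args here,
        # but this toy encoder does not use it
        return [2 * x for x in X]


class SumDecoder:
    """Hypothetical decoder: its state is the sum of the valid encoder outputs."""

    def init_state(self, enc_outputs, valid_len):
        return sum(enc_outputs[:valid_len])  # ignore padded positions

    def __call__(self, X, state):
        return [x + state for x in X]


model = EncoderDecoder(DoublingEncoder(), SumDecoder())
# Source [1, 2, 0] has valid length 2: the trailing 0 is padding
print(model([1, 2, 0], [10, 20], 2))  # [16, 26]
```

Keeping the composition logic in one base class means any encoder and decoder that honor the two interfaces can be combined without changing `forward`.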

The term “state” in the encoder-decoder architecture has probably inspired you to implement this architecture using neural networks with states. In the next section, we will see how to apply RNNs to design sequence transduction models based on this encoder-decoder architecture.
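As a hint of what a decoder "with state" looks like, the following plain-Python sketch replaces a real RNN with a hypothetical linear recurrence `h_t = 0.5 * h_{t-1} + x_t`. Unlike the base interface above, which leaves the return value of `forward` open, this toy `forward` is assumed to return both the outputs and the updated state, so decoding can be resumed one step at a time.

```python
class RecurrentToyDecoder:
    """Hypothetical stateful decoder with state update h_t = 0.5 * h_{t-1} + x_t."""

    def init_state(self, enc_outputs, *args):
        # Assumption: the last encoder output becomes the initial state
        return enc_outputs[-1]

    def forward(self, X, state):
        outputs = []
        for x in X:  # one time step per decoder input
            state = 0.5 * state + x  # new state mixes the old state and the input
            outputs.append(state)
        return outputs, state  # hand back the state so decoding can continue


dec = RecurrentToyDecoder()
state = dec.init_state([1.0, 4.0])
all_at_once, final_state = dec.forward([1.0, 2.0, 3.0], state)
print(all_at_once, final_state)  # [3.0, 3.5, 4.75] 4.75

# Decoding one token at a time while carrying the state gives the same result
step_by_step, s = [], state
for x in [1.0, 2.0, 3.0]:
    y, s = dec.forward([x], s)
    step_by_step.extend(y)
print(step_by_step == all_at_once)  # True
```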

Summary

  • The encoder-decoder architecture can handle inputs and outputs that are both variable-length sequences, and is thus suitable for sequence transduction problems such as machine translation.
  • The encoder takes a variable-length sequence as the input and transforms it into a state with a fixed shape.
  • The decoder maps the encoded state of a fixed shape to a variable-length sequence.

Exercises

  1. Suppose that we use neural networks to implement the encoder-decoder architecture. Do the encoder and the decoder have to be the same type of neural network?
  2. Besides machine translation, can you think of another application where the encoder-decoder architecture can be applied?
