Varol Cagdas Tok

Personal notes and articles.

Recurrent Neural Networks and Attention Mechanisms

Many problems involve sequential data, where the order of elements matters. Examples include time series such as stock prices, sentences in natural language, audio waveforms, and DNA sequences. Standard feedforward neural networks treat each input independently, as if data points were independent and identically distributed (i.i.d.), and expect inputs of a fixed size, so they are poorly suited to these tasks. We trace the evolution of architectures for sequential data, from Recurrent Neural Networks (RNNs) to the attention mechanisms used in models like the Transformer.

The Challenge of Memory: Recurrent Neural Networks (RNNs)

The challenge in modeling sequences is the need for memory. A model's prediction at the current time step often depends on information from previous time steps. A Recurrent Neural Network (RNN) addresses this by introducing a feedback loop.

In a standard feedforward network, information flows in one direction: from input to output. In an RNN, the activations of the hidden layer at a given time step \(t\) are fed back as an input to the same hidden layer at the next time step, \(t+1\). This recurrent connection allows the hidden state \(\mathbf{z}_t\) to act as a form of memory, accumulating information from all previous inputs in the sequence.

The equations for a simple RNN can be written as:

\[\mathbf{z}_t = \sigma(\mathbf{V}\mathbf{x}_t + \mathbf{U}\mathbf{z}_{t-1})\]

\[\mathbf{y}_t = \sigma(\mathbf{W}\mathbf{z}_t)\]

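Here, \(\mathbf{x}_t\) is the input at time \(t\), \(\mathbf{z}_t\) is the hidden state, \(\mathbf{y}_t\) is the output, and \(\sigma\) is an elementwise nonlinearity. The same weight matrices (\(\mathbf{V}\), \(\mathbf{U}\), \(\mathbf{W}\)) are used at every time step. This weight sharing keeps the number of parameters independent of the sequence length and lets the model process sequences of varying lengths. The network is trained by unfolding it through time into a feedforward graph with one copy of the layer per time step and applying backpropagation to that graph; because the weights are shared, the gradient contributions from all time steps are summed. This procedure is called Backpropagation Through Time (BPTT).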
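
To make the two equations concrete, here is a minimal forward pass in plain Python. It is a sketch under stated assumptions, not a reference implementation: \(\sigma\) is taken to be the logistic sigmoid, biases are omitted as in the equations above, the initial hidden state \(\mathbf{z}_0\) is assumed to be all zeros, and the weights are random rather than trained. The helper names (`rnn_forward`, `matvec`, `random_matrix`) are hypothetical. The same `V`, `U` and `W` are reused at every step, which is why one function handles a 2-step and a 7-step sequence alike.

```python
import math
import random

def sigmoid(x):
    """Logistic sigmoid, used here for sigma (an assumption)."""
    return 1.0 / (1.0 + math.exp(-x))

def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def rnn_forward(xs, V, U, W):
    """Apply z_t = sigma(V x_t + U z_{t-1}) and y_t = sigma(W z_t) to a whole sequence.

    The same V, U and W are reused at every step, so any sequence length works.
    """
    z = [0.0] * len(U)  # assumption: initial hidden state z_0 is all zeros
    ys = []
    for x in xs:
        pre = [a + b for a, b in zip(matvec(V, x), matvec(U, z))]
        z = [sigmoid(p) for p in pre]                  # new hidden state (the memory)
        ys.append([sigmoid(p) for p in matvec(W, z)])  # output at this step
    return ys, z

def random_matrix(rows, cols, rng):
    """Small random weights for the demo (no training happens here)."""
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
input_size, hidden_size, output_size = 3, 4, 2
V = random_matrix(hidden_size, input_size, rng)
U = random_matrix(hidden_size, hidden_size, rng)
W = random_matrix(output_size, hidden_size, rng)

short_seq = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
long_seq = [[0.0, 0.0, 1.0]] * 7
ys_short, _ = rnn_forward(short_seq, V, U, W)
ys_long, z_last = rnn_forward(long_seq, V, U, W)
print(len(ys_short), len(ys_long), len(z_last))  # 2 7 4
```

Training this with BPTT would mean differentiating through the loop in `rnn_forward` and summing the gradients that each step contributes to `V`, `U` and `W`.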

The Vanishing Gradient Problem and LSTMs

Simple RNNs struggle to learn long-range dependencies. During BPTT, the gradient flowing from a time step back to a much earlier one is multiplied by one Jacobian \(\partial \mathbf{z}_i / \partial \mathbf{z}_{i-1}\) for every step in between, so over long sequences it can shrink exponentially (vanish) or grow exponentially (explode). Vanishing gradients are common and make it difficult for the network to learn connections between events that are far apart in the sequence.
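
For the recurrence above, each backward step multiplies the gradient by \(\partial \mathbf{z}_t / \partial \mathbf{z}_{t-1} = \operatorname{diag}(\sigma'(\mathbf{V}\mathbf{x}_t + \mathbf{U}\mathbf{z}_{t-1}))\,\mathbf{U}\). The toy below is a deliberate simplification, not the document's model: a single hidden unit with recurrent weight `u`, no input, and either a linear or a tanh activation. The function name `gradient_through_time` and the starting value `z0 = 0.5` are hypothetical choices for illustration. It applies the chain rule one step at a time, so the exponential effect of repeated multiplication is visible directly.

```python
import math

def gradient_through_time(u, steps, z0=0.5, activation="linear"):
    """Return dz_T/dz_0 for the scalar recurrence z_t = f(u * z_{t-1}).

    Inputs are omitted so only the recurrent path is visible. By the chain
    rule, the gradient is the product of one local derivative per step.
    """
    z = z0
    grad = 1.0
    for _ in range(steps):
        a = u * z
        if activation == "tanh":
            z = math.tanh(a)
            local = u * (1.0 - z * z)  # d tanh(u * z_prev) / d z_prev
        else:
            z = a
            local = u                  # d (u * z_prev) / d z_prev
        grad *= local                  # multiply in this step's Jacobian factor
    return grad

for u in (0.9, 1.1):
    print("linear", u, gradient_through_time(u, 50))
print("tanh", 1.1, gradient_through_time(1.1, 50, activation="tanh"))
```

With the linear activation, 50 steps turn \(u = 0.9\) into a gradient of about 0.005 (vanishing) and \(u = 1.1\) into one of about 117 (exploding). With tanh, every factor is additionally scaled by \(1 - z^2 < 1\), so the same \(u = 1.1\) now produces a vanishing gradient.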

The Long Short-Term Memory (LSTM) network was designed to mitigate this problem. An LSTM is a more complex type of recurrent unit that introduces a cell state \(\mathbf{s}_t\), which acts as a conveyor belt carrying information across time steps. The LSTM can add information to or remove information from this cell state using gates:

  1. Forget Gate: Decides what information from the previous cell state \(\mathbf{s}_{t-1}\) should be discarded.
  2. Input Gate: Decides what new information from the current input \(\mathbf{x}_t\) should be stored in the cell state.
  3. Output Gate: Decides what part of the cell state should be used to compute the hidden state \(\mathbf{z}_t\) and the final output.
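
Each gate is a learned layer with a sigmoid activation whose outputs, between 0 and 1, scale the information passing through it elementwise. Because the cell state is updated mainly by these elementwise scalings and additions, rather than by repeated multiplication with the same weight matrix, gradients can flow across many more time steps. This gating mechanism lets LSTMs selectively remember or forget information over long periods, making them better at capturing long-range dependencies than simple RNNs.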
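
The gates can be sketched as one step of an LSTM cell. The text does not spell out the equations, so this assumes the common formulation: each gate, plus a tanh "candidate" layer, is an affine map of the concatenation \([\mathbf{x}_t; \mathbf{z}_{t-1}]\) with its own bias, and \(\mathbf{s}_t = \mathbf{f} \odot \mathbf{s}_{t-1} + \mathbf{i} \odot \mathbf{g}\), \(\mathbf{z}_t = \mathbf{o} \odot \tanh(\mathbf{s}_t)\). The names `lstm_step`, `affine` and the `params` dictionary layout are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def affine(W, b, v):
    """Compute W v + b, with W given as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def lstm_step(x_t, z_prev, s_prev, params):
    """One step of a standard LSTM cell (assumed formulation, see lead-in).

    params maps "forget", "input", "candidate" and "output" to a (W, b) pair
    that acts on the concatenation [x_t; z_prev].
    """
    v = list(x_t) + list(z_prev)
    f = [sigmoid(a) for a in affine(*params["forget"], v)]       # fraction of s_prev to keep
    i = [sigmoid(a) for a in affine(*params["input"], v)]        # how much new info to store
    g = [math.tanh(a) for a in affine(*params["candidate"], v)]  # candidate new info
    o = [sigmoid(a) for a in affine(*params["output"], v)]       # which parts of s_t to expose
    s_t = [fk * sk + ik * gk for fk, sk, ik, gk in zip(f, s_prev, i, g)]  # cell state update
    z_t = [ok * math.tanh(sk) for ok, sk in zip(o, s_t)]                   # new hidden state
    return z_t, s_t

hidden, inp = 2, 3
zero_W = [[0.0] * (inp + hidden) for _ in range(hidden)]
zero_b = [0.0] * hidden
params = {name: (zero_W, zero_b) for name in ("forget", "input", "candidate", "output")}
x, z0, s0 = [1.0, 2.0, 3.0], [0.0, 0.0], [1.0, -1.0]

# All-zero weights: every gate is sigmoid(0) = 0.5 and the candidate is tanh(0) = 0,
# so the cell state is simply halved.
z, s = lstm_step(x, z0, s0, params)
print(s)  # [0.5, -0.5]

# Forget gate nearly open (bias +10), input gate nearly closed (bias -10):
# the cell state passes through almost unchanged.
keep = dict(params, forget=(zero_W, [10.0] * hidden), input=(zero_W, [-10.0] * hidden))
_, s_kept = lstm_step(x, z0, s0, keep)
print([round(v, 4) for v in s_kept])  # [1.0, -1.0]
```

The second call shows the "remember" behavior: when the forget gate saturates near 1 and the input gate near 0, \(\mathbf{s}_t \approx \mathbf{s}_{t-1}\), and the gradient along the cell state is scaled by roughly 1 per step instead of shrinking.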

The Encoder-Decoder Architecture and the Bottleneck Problem

A common framework for sequence-to-sequence tasks, like machine translation, is the encoder-decoder architecture.

  • The encoder, an RNN or LSTM, reads the entire input sequence (e.g., a sentence in English) and compresses it into a single fixed-size context vector, which is the final hidden state of the encoder.
  • The decoder, another RNN or LSTM, is initialized with this context vector and generates the output sequence one element at a time (e.g., the translated sentence in German).

This architecture has a limitation: the single context vector becomes an information bottleneck. It must encapsulate the entire meaning of the input sequence, which becomes harder as the sequence grows. Information from the beginning of a long sequence tends to be degraded or lost by the time the encoder finishes processing.
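
The bottleneck is easy to see in code: whatever the input length, the encoder hands the decoder a vector of the same fixed size. The sketch below assumes plain tanh RNN cells with random, untrained weights, a zero start-of-sequence input, and a decoder whose output at each step is simply its hidden state (a simplification; a real decoder would project it to vocabulary scores). The functions `rnn_cell`, `encode` and `decode` are hypothetical names.

```python
import math
import random

def rnn_cell(x, z, Wx, Wz):
    """tanh RNN update: returns tanh(Wx x + Wz z)."""
    return [math.tanh(sum(a * xi for a, xi in zip(row_x, x)) +
                      sum(b * zi for b, zi in zip(row_z, z)))
            for row_x, row_z in zip(Wx, Wz)]

def encode(xs, Wx, Wz):
    """Read the whole input sequence and keep only the final hidden state."""
    z = [0.0] * len(Wz)
    for x in xs:
        z = rnn_cell(x, z, Wx, Wz)
    return z  # the fixed-size context vector

def decode(context, steps, Dx, Dz):
    """Start from the context vector and emit one element per step.

    Simplification: each output is the decoder's hidden state, fed back as the
    next input.
    """
    z = list(context)
    y = [0.0] * len(z)  # assumed start-of-sequence input
    outputs = []
    for _ in range(steps):
        z = rnn_cell(y, z, Dx, Dz)
        outputs.append(z)
        y = z
    return outputs

def rand(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(1)
H, D = 4, 3  # hidden size and input size
Wx, Wz = rand(H, D, rng), rand(H, H, rng)
Dx, Dz = rand(H, H, rng), rand(H, H, rng)

short_input = rand(3, D, rng)    # 3 input vectors
long_input = rand(200, D, rng)   # 200 input vectors
c_short = encode(short_input, Wx, Wz)
c_long = encode(long_input, Wx, Wz)
print(len(c_short), len(c_long))  # 4 4
translation = decode(c_long, 5, Dx, Dz)
```

Both a 3-step and a 200-step input are squeezed into the same four numbers, and `decode` never sees anything else.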

The Attention Mechanism: A Paradigm Shift

The attention mechanism was introduced to overcome this bottleneck. The core idea is to allow the decoder to look back at the entire input sequence at every step of the generation process. Instead of relying on a single context vector, the decoder learns to pay "attention" to the most relevant parts of the input sequence when producing each output element.

Here's how it works:

  1. At each decoding step, the decoder's current hidden state (the "query") is compared with all the hidden states of the encoder (the "keys"). This comparison produces a set of alignment scores.
  2. These scores are converted into a set of weights (the "attention weights") using a softmax function. These weights represent the importance of each input word for generating the current output word.
  3. A context vector is computed as a weighted average of the encoder's hidden states, using the attention weights.
  4. This context vector, which is tailored to the specific decoding step, is then combined with the decoder's hidden state to produce the final output.

Attention provides a direct path from every decoding step to any part of the input sequence. This greatly eases the learning of long-range dependencies and removes the constraint of a single fixed-size context vector.
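
Steps 1 to 3 fit in a few lines. One substitution to flag: the original attention for translation scored query-key pairs with a small learned network (additive attention), while this sketch uses a plain dot product as the alignment score because it is shorter. The encoder states serve as both keys and values, step 4 is left out, and `attend` and the example vectors are hypothetical.

```python
import math

def softmax(scores):
    m = max(scores)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """One decoding step of (dot-product) attention.

    1. score each key against the query
    2. softmax the scores into attention weights
    3. return the weighted average of the values, plus the weights
    """
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    context = [sum(w * v[j] for w, v in zip(weights, values))
               for j in range(len(values[0]))]
    return context, weights

encoder_states = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
decoder_state = [2.0, 0.0]
context, weights = attend(decoder_state, encoder_states, encoder_states)
print([round(w, 3) for w in weights])  # [0.787, 0.107, 0.107]
```

The decoder state points in the same direction as the first encoder state, so that state receives most of the weight, and the context vector is pulled toward it. At the next decoding step a different query yields a different context vector, which is exactly what a single fixed context vector cannot do.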

Self-Attention and the Transformer

The paper "Attention Is All You Need" (Vaswani et al., 2017) proposed a new architecture, the Transformer, that removes recurrence and convolutions and relies on attention instead.

The key innovation is self-attention. In self-attention, the elements of a single sequence (e.g., the words in an input sentence) attend to each other. This lets the model build a context-aware representation of each element by considering its relationships with all other elements in the sequence. Because self-attention by itself ignores word order, the Transformer adds positional encodings to its input embeddings. Each encoder layer combines self-attention with a position-wise feed-forward network; each decoder layer combines masked self-attention (so a position cannot attend to later positions), encoder-decoder attention over the encoder's outputs, and a feed-forward network.
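
The core of self-attention in the Transformer is scaled dot-product attention, \(\operatorname{softmax}(\mathbf{Q}\mathbf{K}^\top / \sqrt{d_k})\,\mathbf{V}\), where queries, keys and values are all linear projections of the same sequence \(\mathbf{X}\). The sketch below covers a single head only, with no masking, positional encodings or feed-forward layers, and uses identity projection matrices in the demo purely for readability; `self_attention`, `matmul` and the example inputs are hypothetical.

```python
import math

def matmul(A, B):
    """Multiply matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(X, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over the rows of X."""
    Q, K, V = matmul(X, Wq), matmul(X, Wk), matmul(X, Wv)
    d_k = len(K[0])
    scores = [[sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in K]
              for q_row in Q]
    weights = [softmax(row) for row in scores]  # row i: how much token i attends to each token
    return matmul(weights, V), weights

I2 = [[1.0, 0.0], [0.0, 1.0]]
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]        # three "tokens" with 2 features each
out, attn = self_attention(X, I2, I2, I2)
for row in attn:
    print([round(w, 3) for w in row])
```

Every output row is a weighted mix of all the value rows, so each token's new representation depends on the whole sequence. Shuffling the rows of `X` only shuffles the rows of the output, which is why positional encodings are needed to preserve word order.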

Conclusion

The path from simple RNNs to the Transformer reflects a sustained effort to capture complex, long-range dependencies in sequential data. RNNs introduced the concept of a learned memory. LSTMs provided a mechanism for maintaining that memory over long spans. The attention mechanism shifted the paradigm from compressing information into a single vector to letting models focus on the relevant parts of the input. Self-attention, which applies this principle within a single sequence, is the core building block of today's large language models (LLMs).