Why and how attention works in neural nets

What does it mean for a machine to “pay attention”? Is it possible for dead transistors to do something that seems so alive?

Possibly. ML researchers have been working on neural architectures featuring so-called “attention” mechanisms. They are proving useful in different applications of ML, especially tasks with sequence-style inputs or outputs like text.

Attention in Seq2Seq nets

Seq2seq models address the problem of variable-length inputs or outputs, such as the words of a sentence. They do so by using a recurrent neural network, with stop symbols marking the end of an input or output.

The recurrent network passes its hidden state from one “frame” to the next. In the case of text processing, these frames are words. For translation tasks, the input may be an English sentence, and the output its French translation. The seq2seq architecture uses an encoder-decoder layout in which the encoder reads the input text, encoding each word one by one. The result is a “context vector” which is meant to summarize the whole sentence. The decoder then generates the translation word by word, ending with a stop symbol.
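To make the encoder half concrete, here is a minimal sketch in NumPy. It assumes a plain tanh recurrent cell with random, untrained weights; the names encode, W_in and W_rec are made up for illustration and are not taken from any paper or library. The point is only that the context vector comes out the same size whether the sentence has 3 words or 40.

```python
import numpy as np

def encode(word_vectors, W_in, W_rec):
    """Run a plain tanh RNN over the words and return the final hidden state."""
    hidden = np.zeros(W_rec.shape[0])
    for x in word_vectors:
        # Each step mixes the current word with the state carried over so far.
        hidden = np.tanh(W_in @ x + W_rec @ hidden)
    return hidden  # this is the fixed-length "context vector"

# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(0)
embed_dim, hidden_dim = 8, 16
W_in = rng.normal(scale=0.1, size=(hidden_dim, embed_dim))
W_rec = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))

short_sentence = rng.normal(size=(3, embed_dim))   # 3 "words"
long_sentence = rng.normal(size=(40, embed_dim))   # 40 "words"

context_short = encode(short_sentence, W_in, W_rec)
context_long = encode(long_sentence, W_in, W_rec)
print(context_short.shape, context_long.shape)  # (16,) (16,)
```

However long the input gets, everything the decoder learns about it has to fit into those 16 numbers.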

One problem is having to store information about a variable-length input in a fixed-length context vector. The seminal paper on attention in seq2seq nets is “Neural Machine Translation by Jointly Learning to Align and Translate” by Bahdanau et al.

They describe the problem:

A potential issue with this encoder–decoder approach is that a neural network needs to be able to compress all the necessary information of a source sentence into a fixed-length vector. This may make it difficult for the neural network to cope with long sentences, especially those that are longer than the sentences in the training corpus.

“Neural Machine Translation by Jointly Learning to Align and Translate” by Bahdanau et al.

They address the issue by giving the decoder block a second look at the input sentence. Looking at the whole sentence for every step of output would probably be overkill, and would likely make training the network difficult.

Instead, there is a function which compares the annotation of each input word to the most recent hidden state of the decoder, and scores how relevant that word is to the next output word. The authors call this an alignment model.
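Here is a rough sketch of that comparison in NumPy. It assumes an additive scoring form, score = v · tanh(W_s s + W_h h), followed by a softmax over the input words. The weights are random and untrained, and the names attend, W_s, W_h and v are my own labels, so treat this as an illustration rather than the authors' implementation.

```python
import numpy as np

def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    exp = np.exp(scores - scores.max())  # subtract max for numerical stability
    return exp / exp.sum()

def attend(decoder_state, annotations, W_s, W_h, v):
    """Score every annotation against the decoder state, then blend them."""
    # Broadcasting adds the projected decoder state to each projected annotation.
    scores = np.tanh(decoder_state @ W_s + annotations @ W_h) @ v  # one per word
    weights = softmax(scores)          # soft alignment over the input words
    context = weights @ annotations    # weighted average of the annotations
    return weights, context

# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(0)
n_words, annotation_dim, state_dim, align_dim = 6, 10, 12, 7
annotations = rng.normal(size=(n_words, annotation_dim))
decoder_state = rng.normal(size=state_dim)
W_s = rng.normal(size=(state_dim, align_dim))
W_h = rng.normal(size=(annotation_dim, align_dim))
v = rng.normal(size=align_dim)

weights, context = attend(decoder_state, annotations, W_s, W_h, v)
print(weights.shape, context.shape)  # (6,) (10,)
```

The decoder gets a fresh context vector at every output step, and the weights say how much each input word contributed to it.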

The attention model is effective, and the paper advanced the state-of-the-art for neural machine translation.

So why is attention effective for seq2seq translation tasks? Here are my theories.

Input space scales with length of input sentence

During decoding, the decoder unit can receive information directly from any of the words of the input sentence. Therefore, the longer the input sentence, the larger the space of potential inputs to the decoder.

This is as opposed to the original encoder-decoder framework, which had only a fixed-size context vector to filter everything through.

Knowing where to look is easier

The hidden state of the decoder module must at a minimum know “where to look” in the input sentence to get information about the next step of translation.

This should be easier, and more compressible, than “remembering everything” in the hidden state. It is analogous to how a person remembers which chapter of a book to look in, or which search term to use, rather than remembering all the details all the time.

Bidirectionally-encoded annotations

The attention paper used bi-directional encodings of each word to generate annotations. This means that every annotation contains information about the whole sentence, especially the words right before and after. When an annotation is selected by the alignment model, it can provide big hints to the decoder about where to look next.
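A small sketch of how such annotations can be built, again in NumPy. To keep it short I assume plain tanh recurrent cells with random weights instead of the gated units a real translation model would use, and the function names are hypothetical. One RNN reads the sentence left to right, another reads it right to left, and the two states for each word are concatenated.

```python
import numpy as np

def run_rnn(word_vectors, W_in, W_rec):
    """Return the hidden state after each word, in reading order."""
    hidden = np.zeros(W_rec.shape[0])
    states = []
    for x in word_vectors:
        hidden = np.tanh(W_in @ x + W_rec @ hidden)
        states.append(hidden)
    return np.stack(states)

def bidirectional_annotations(word_vectors, fwd, bwd):
    """Concatenate forward and backward states so each word sees both sides."""
    forward_states = run_rnn(word_vectors, *fwd)
    # Read the sentence in reverse, then flip back to the original word order.
    backward_states = run_rnn(word_vectors[::-1], *bwd)[::-1]
    return np.concatenate([forward_states, backward_states], axis=1)

# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(0)
embed_dim, hidden_dim = 8, 16
fwd = (rng.normal(scale=0.3, size=(hidden_dim, embed_dim)),
       rng.normal(scale=0.3, size=(hidden_dim, hidden_dim)))
bwd = (rng.normal(scale=0.3, size=(hidden_dim, embed_dim)),
       rng.normal(scale=0.3, size=(hidden_dim, hidden_dim)))

sentence = rng.normal(size=(5, embed_dim))  # 5 "words"
annotations = bidirectional_annotations(sentence, fwd, bwd)
print(annotations.shape)  # (5, 32)
```

In this sketch, the first half of each annotation depends only on the words up to and including that position, and the second half only on the words from that position onward.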

Flow of gradient

The flow of gradient is a perhaps under-appreciated aspect of the attention mechanism. The authors do mention it:

Instead, the alignment model directly computes a soft alignment, which allows the gradient of the cost function to be backpropagated through.

“Neural Machine Translation by Jointly Learning to Align and Translate” by Bahdanau et al.

However, they don’t report any testing of the effects of this backpropagation. For example, they could turn off backpropagation from the alignment model to the encoder units and see if there is any difference in performance.

In general, architectures that allow for more gradient flow have performed well, e.g. DenseNets, Highway Networks and ResNets.

Attention in visual CNN

Attention-type mechanisms also exist in convolutional neural nets used for image processing. They can be useful in instances where only a small part of the image is relevant, for instance, a large image with a small object, or an image with a sequence of objects.

In one highly cited paper, “Multiple Object Recognition with Visual Attention” by Ba et al, the authors demonstrate a CNN-RNN hybrid with an attention mechanism which is used to detect street address house numbers.

In this case, the attention model looks at different positions in the image, and the classification model decides whether a digit is present in the current frame or not. The attention model then selects the next place to look.
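The “look at one position” step can be sketched as cropping a small window, or glimpse, out of the image. This NumPy sketch assumes a 2-D grayscale image and an odd window size, and it leaves out the learned part entirely (how the next location is chosen). The function name extract_glimpse is made up for illustration.

```python
import numpy as np

def extract_glimpse(image, center_row, center_col, size):
    """Crop a size x size window centred on a pixel (size assumed odd)."""
    half = size // 2
    # Zero-pad so that windows hanging over the image edge are still full size.
    padded = np.pad(image, half, mode="constant")
    # After padding, original pixel (r, c) sits at (r + half, c + half),
    # so the window's top-left corner in the padded image is (center_row, center_col).
    return padded[center_row:center_row + size, center_col:center_col + size]

# A blank 32 x 96 "street image" with one bright 10 x 10 blob standing in for a digit.
image = np.zeros((32, 96))
image[10:20, 40:50] = 1.0

on_target = extract_glimpse(image, 15, 45, size=9)   # centred on the blob
off_target = extract_glimpse(image, 15, 5, size=9)   # empty background
corner = extract_glimpse(image, 0, 0, size=9)        # overlaps the image edge

print(on_target.sum(), off_target.sum(), corner.shape)  # 81.0 0.0 (9, 9)
```

The classifier only ever sees the 9 x 9 window, so most of the image costs nothing to process, as long as the attention model picks good places to look.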

Attention models may be useful in different visual tasks. However, for many tasks, they have been overtaken by architectures such as YOLO or feature pyramid networks (FPNs).

Attention in agent-based learning

One more area where attention-style mechanisms have been studied is in agent-based learning. A seminal paper in curiosity-based learning was “Curiosity-driven Exploration by Self-supervised Prediction” by Pathak et al., based on an agent playing video games.

In real-world, and even game-world, situations there is a lot of stochastic or irrelevant information that the agent can neither affect nor be affected by. For example, in the real world, shadows or blowing leaves generate a lot of visual information but generally don’t affect the “agent”. In video games, distant enemies may be seen but can’t yet affect the agent.

How can the agent learn which stimuli are important, and which aren’t? Pathak et al. use an inverse dynamics model:

Inverse dynamics model

  • Inputs: pixel space at current and previous time steps
  • Intermediate step: compressed representation of pixel space
  • Outputs: attempt to guess action actually taken at previous time step

The intermediate representation of pixel space learned here is incentivized to ignore events that the agent could not have caused. This embedding acts as an attention mechanism, by focusing the agent on only those elements of the game it can affect.

Forward dynamics model

  • Inputs: pixel space at the current time step, plus the action taken
  • Intermediate step: compressed representation of pixel space, co-learned with inverse dynamics model
  • Outputs: prediction of the compressed representation of pixel space at the next time step

The forward dynamics model is used to calculate an intrinsic or curiosity-based reward, which is another interesting discussion, but outside the scope of this post.
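For readers who want to see how the two models fit together, here is a shape-level sketch in NumPy. Everything in it is an assumption for illustration: single linear layers with random, untrained weights, a flattened 64-pixel frame, four discrete actions, and a forward model that takes the action as an extra input. The reward is half the squared prediction error in the compressed space, with any scaling factor left out.

```python
import numpy as np

# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(0)
n_pixels, feat_dim, n_actions = 64, 8, 4
W_enc = rng.normal(scale=0.1, size=(n_pixels, feat_dim))
W_inv = rng.normal(scale=0.1, size=(2 * feat_dim, n_actions))
W_fwd = rng.normal(scale=0.1, size=(feat_dim + n_actions, feat_dim))

def encode(frame):
    """Compress a flattened frame into the shared representation."""
    return np.tanh(frame @ W_enc)

def inverse_model(feat_prev, feat_curr):
    """Guess which action led from the previous frame to the current one."""
    return np.concatenate([feat_prev, feat_curr]) @ W_inv  # one score per action

def forward_model(feat_prev, action):
    """Predict the next compressed representation from state and action."""
    one_hot = np.eye(n_actions)[action]
    return np.concatenate([feat_prev, one_hot]) @ W_fwd

def curiosity_reward(feat_prev, action, feat_curr):
    """Half the squared error of the forward model's prediction."""
    error = forward_model(feat_prev, action) - feat_curr
    return 0.5 * float(error @ error)

frame_prev = rng.normal(size=n_pixels)
frame_curr = rng.normal(size=n_pixels)
feat_prev, feat_curr = encode(frame_prev), encode(frame_curr)

action_scores = inverse_model(feat_prev, feat_curr)
reward = curiosity_reward(feat_prev, action=2, feat_curr=feat_curr)
print(action_scores.shape, reward >= 0.0)  # (4,) True
```

In a trained version of this setup, the encoder would be shaped by the inverse model's training signal, which is what pushes it to keep only the parts of the frame that help explain the agent's own actions.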

Overall Meaning of Attention

The exact mechanism of attention varies across the cases we examined here. However, what seems common among them is the notion that deciding which stimuli are important is a different task from decoding those stimuli.

The attention mechanism operates at a higher cognitive layer than the classification model. The attention layer decides what the classification layer gets to see.

This mirrors “separation of concerns” or “single responsibility” in the world of computer science. This principle states that one function or object should have a single responsibility.

The extent to which attention models are more accurate or efficient than non-attention models is determined by the extent to which these tasks are different. The more different they are, the better separate modules with separately learned parameters will perform the overall task.

