trees are harlequins, words are harlequins — the transformer … “explained”?

Okay, here’s my promised post on the Transformer architecture.  (Tagging @sinesalvatorem​ as requested)

The Transformer architecture is the hot new thing in machine learning, especially in NLP.  In the course of roughly a year, the Transformer has given us things like:

  • GPT-2, everyone’s new favorite writer-bot, with whose work I am sure you are familiar
  • GPT (the first one) and its superior successor, BERT, which can achieve state-of-the-art results with unprecedented data efficiency on numerous language understanding tasks, with almost no hyperparameter tuning – in concrete terms, this means “something that took me, nostalgebraist, a month to do in 2018 now takes me 30 minutes, and the results are better”
  • AlphaStar??  There’s still no paper on it yet, AFAIK, but the blog post says it has a Transformer as one of its components EDIT: the AlphaStar paper is out, see my post here for details

This thing is super good.  It honestly spooks me quite a bit, and I’m not usually spooked by new neural net stuff.

However, it doesn’t seem like an intuitive understanding of the Transformer has been disseminated yet – not the way that an intuitive understanding of CNNs and RNNs has.

The original paper introducing it, “Attention Is All You Need,” is suboptimal for intuitive understanding in many ways, but typically people who use the Transformer just cite/link to it and call it a day.  The closest thing to an intuitive explainer that I know of is “The Illustrated Transformer,” but IMO it’s too light on intuition and too heavy on near-pseudocode (including stuff like “now you divide by 8,” as the third of six enumerated “steps” which themselves only cover part of the whole computation!).

This is a shame, because once you hack through all the surrounding weeds, the basic idea of the Transformer is really simple.  This post is my attempt at an explainer.

I’m going to take a “historical” route where I go through some other, mostly older architectural patterns first, to put it in context; hopefully it’ll be useful to people who are new to this stuff, while also not too tiresome to those who aren’t.

(1) Classic fully-connected neural networks.

These treat every distinct input variable as a completely distinct, special, and unique snowflake.

When they learn to recognize something involving one particular variable, or set thereof, they do not make any automatic generalizations to other variables / sets thereof.

This makes sense when you’re doing something like a regression in a social scientific or medical study, where your inputs might be (say) demographic variables like “age” or “number of alcoholic drinks per week.”  But it’s really bad when your variables have some known, structured relationship, like a spatial or temporal layout.

If they’re pixels in an image, say, a fully-connected network wouldn’t be able to learn a pattern like “brighter pixel on left, darker pixel on right”; it’d have to separately learn “(0, 0) brighter than (1, 0),” and “(1, 0) brighter than (2, 0),” and “(0, 1) brighter than (1, 1),” and so on.

(2) Convolutional neural nets (CNNs).

These know about the spatial layout of the inputs.  They view inputs in relative terms: they don’t learn things about “the pixel at position (572, 35),” they learn things about “the pixel at the center of where I’m looking,” “the pixel to its left,” etc. Then they slide along “looking” at different parts of the image, looking for the same relative-to-the-center patterns in each area.

CNNs really have two separate qualities that distinguish them from fully-connected nets: weight sharing and locality.

  • Weight sharing means “there’s some fixed set of computations you apply, defined relative to a central position, and you apply them at each position.”
  • Locality means each of the computations can only look at things fairly close to the center. For example, you might look for patterns in a 3×3 or 5×5 window.
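To make those two ideas concrete, here’s a tiny numpy sketch of a 1D convolution.  Everything in it (the filter values, the toy signal) is made up for illustration – the point is just that one shared filter slides along the sequence (weight sharing) while only ever seeing a small window (locality).

```python
import numpy as np

def conv1d(x, w):
    """Slide one shared filter `w` along sequence `x` (weight sharing),
    only ever looking at len(w) neighbors at a time (locality)."""
    k = len(w)
    return np.array([np.dot(x[i:i + k], w) for i in range(len(x) - k + 1)])

signal = np.array([0., 0., 1., 1., 1., 0., 0.])
edge_filter = np.array([-1., 1.])  # responds + to a rise, - to a fall
print(conv1d(signal, edge_filter))
```

The same two weights get reused at every position – that’s all “weight sharing” means.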

Weight sharing is crucial for any spatially or temporally structured input, including text. However, locality is not very appropriate for text.

I think about it this way: every salient thing in an image (a dog, a dog’s nose, an edge, a patch of color) can be understood in a self-contained way, without looking outside of that thing.  Images don’t have pronouns, for example, or other systems of reference that require you to look outside a thing to grok that thing.

Except in weird ambiguous cases, it’s usually not like “oh, I see that’s a dog now, but I had to look somewhere outside the dog to know that.”  So you can start with small things and move upward, hierarchically: “ah, that’s an edge –> ah, that’s an oblong thing made of edges –> ah, that’s a dog’s nose –> ah, that’s a dog’s head –> ah, that’s a dog.”  Each thing is defined only by the things inside it (which are, by definition, smaller).

With text, you can’t do that.  A sentence can have a pronoun at one end whose antecedent is all the way at the other, for example.  There’s no way to safely break a sentence down into units that can definitely be understood on their own, and only then linked up to one another.  So locality is out.  (Well, okay, plenty of people do use CNNs on text, including me.  They work well enough for plenty of purposes.  But one can do better.)

(3) Recurrent neural nets (RNNs).

Like CNNs, these sequentially “slide along” the input, doing some version of the same computation step at each position (weight sharing).

But instead of looking at the current position plus a small local window around it, they instead look at

  • the current position
  • their own output, after the last position they looked at

When the input is textual, this feels a lot like “reading” – the RNN processes the first word, writes down a summary of what it’s gleaned so far, then processes the second word in light of the summary, updates the summary, processes the third word in light of the summary, updates the summary, etc.
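A hedged sketch of that read-and-summarize loop, with made-up sizes and random weights (a real RNN learns these, and LSTMs/GRUs add gating on top):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_hidden = 4, 8                       # word vector size, "summary" size
W_x = rng.normal(size=(d_hidden, d_in))     # how the current word updates the summary
W_h = rng.normal(size=(d_hidden, d_hidden)) # how the old summary carries forward

def rnn_read(words):
    """Process words one at a time, updating a fixed-size summary vector."""
    h = np.zeros(d_hidden)                  # the "scratchpad," blank at first
    for x in words:
        h = np.tanh(W_x @ x + W_h @ h)      # new summary = f(current word, old summary)
    return h

sentence = [rng.normal(size=d_in) for _ in range(5)]
summary = rnn_read(sentence)
print(summary.shape)  # (8,) -- same size no matter how long the sentence is
```

Note the constraint that section (3b) dwells on: however long the input gets, everything has to squeeze through that one fixed-size `h`.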

Usually people use RNN architectures (LSTMs/GRUs) that can learn about when to “forget” information (remove it from the summary) vs. when to pass it along.  This lets them do things like writing down “there’s this pronoun I still don’t have a referent for,” and pass that along across a potentially long range until it finds a fitting antecedent.

(3b) RNNs are awkward.

Even though they do something that resembles sequential reading, RNNs have an inherently awkward and unwieldy task to do.

An RNN “reads” in one direction only, which creates an asymmetry.  Near the start of a sentence, the output can only use information from a few words; near the end, it can use information from all the words.  (Contrast with a CNN, which processes every position in the same way.)

This is problematic when, for example, words at the start of a sentence can only be understood in light of words that come much later.  The RNN can understand later words in light of earlier ones (that’s the whole idea), but not the reverse.

One can get around this to some extent by using multiple RNN layers, where the later ones are like “additional reading passes,” and also by having an RNN going one way and another one going the other way (“BiLSTMs,” typically found in pre-Transformer state-of-the-art architectures).

But there is still a fundamental awkwardness to the RNN’s structure, where it can only process dependencies between words using a “scratchpad” of finite length, and it has to use the same scratchpad for handling all dependencies, short- and long-range.

To help see what I mean, read the following sentence:

It was a truly beautiful bike, and though there were bigger and sturdier hogs in Steve’s collection, he’d always favored that one just for the sheer aesthetics.

This is not a very difficult sentence, all in all.  But think about the criss-crossing relationships involved.   When we first hit “bike” (reading left-to-right), we might be imagining a bicycle; it only becomes clear it’s a motorcycle when we get to “hogs,” and that’s only through an indirect implication (that clause is about other “hogs,” not the “bike” mentioned earlier, but implies the “bike” is a “hog” too).  And what about “hog” itself?  In isolation it could be a pig – just as we can only understand “bike” in light of “hog,” we can only understand “hog” in light of “bike” (perhaps with some help from “collection,” although a pig collection is not inconceivable).

That’s just the word-level ambiguity.  The meaning of the sentence is interlaced in a similar way.  Only when you get to “sheer aesthetics” do you really get why the sentence began with “it was a truly beautiful...”, and only after linking that all together can we perceive what the sentence implies about this guy named Steve, his motorcycle collection, and his attitudes re: same.

An RNN has to bumble its way through this sentence, writing things down on a scratchpad of limited size, opportunistically omitting some details and hoping they won’t turn out to matter.  “Okay, gotta remember that this ‘it’ needs an antecedent [scribbles] … I have limited room, I can probably get away with a hazy sense of something about beauty and bicycles, oh jeez here’s an ‘and though’ clause, uh, what’s this about pigs? I’ve got [checks notes] ‘pretty bicycle past tense,’ no pigs there, I’ll add ‘pigs’ and hope it’s explained later… what was the ‘though’ about, anyway? oh dear…”

(4) Attention.

Attention was (I believe) first invented for processing pairs of texts, like for textual entailment (“does Sentence 1 imply Sentence 2, contradict it, or neither?”).

In that case, people wanted a model that would compare each word/phrase in Sentence 1 with each word/phrase in Sentence 2, to figure out which ones were probably referring to the same topic, or whatever.

Attention is just that.  You have two sequences of words (or generally “positions”).  You form a big grid, with one sequence on the vertical axis and one on the horizontal, so each cell contains one possible pair of words.  Then you have some way of deciding when words “match,” and for each word, you do some computation that combines it with the ones it “matched with.”

As I said, this was invented for cases where you had two different texts you wanted to compare.  But you can just as easily do it to compare the same text with itself.  (This is called “self-attention,” but it’s become so common that people are tending to drop the “self-” part.)

You can see how this would help with stuff like resolving pronouns or word ambiguities.  Instead of keeping every piece of ambiguity on some finite scratchpad, and hoping you’ll have enough room, you immediately link every word with every other word that might help illuminate it.  Pronouns and noun phrases link up in one step.  “Bike” and “hog” link up in one step.

(4b) More on attention.

There are lots of ways attention can work.  Here’s a sketch of one way, the one used in the Transformer (roughly).

Imagine the words are trying to pair up on a dating site.  (I know, I know, but I swear this is helpful.)  For each word, you compute:

  • a key: the word’s “dating profile” (e.g. the profile for “bike” might include “I’m a noun of neuter gender”) 
  • a query: what the word is looking for when it trawls dating profiles (a pronoun like “it” might say “I’ll match with nouns of neuter gender”)
  • a value: other info about what the word means, which might not be as relevant to this matching process (say, everything else about what the word “bike” means)

For each word, you use the keys and queries to figure out how much the word matches itself and how much it matches each other word.  Then you sum up the values, weighted by the “match scores.”  In the end, you might get something combining a lot of the word’s original value with a bit of some other words’ values, representing something like “I’m still a pronoun, but also, I stand in for this noun and kinda mean the same thing as it.”
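This matching-then-weighted-summing step can be written down in a few lines of numpy.  The sketch below is the “scaled dot-product attention” from the original paper, with made-up random Q/K/V matrices standing in for the learned ones; the divide-by-√d is the mysterious “now you divide by 8” step mentioned earlier (8 = √64, the per-head key size in the paper).

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V):
    """Each word's query is scored against every word's key (the "dating
    profiles"); scores become weights; each word's output is the
    score-weighted sum of all the values."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)      # match score for every (query, key) pair
    weights = softmax(scores)          # each row is a set of weights summing to 1
    return weights @ V

n_words, d = 5, 16
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(n_words, d)) for _ in range(3))
out = attention(Q, K, V)
print(out.shape)  # (5, 16): one blended value per word
```

The `scores` matrix is exactly the “big grid” from section (4): one cell per possible pair of words.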

Since words can relate in many different ways, it’s a bit restrictive to have only one key/query/value per word.  And you don’t have to!  You can have as many as you want.  This is called “multi-headed” attention, where the number of keys/queries/values per word is the “number of attention heads.”
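Sticking with numpy, here’s one common way the multi-headed version is arranged (sizes made up): the feature dimension gets chopped into one chunk per head, each chunk does its own independent matching, and the results are concatenated back together.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(Q, K, V, n_heads):
    """Run attention independently in n_heads chunks of the feature
    dimension, so each head can learn a different way for words to
    "match," then concatenate the per-head results."""
    n, d = Q.shape
    outs = []
    for h in range(n_heads):
        sl = slice(h * d // n_heads, (h + 1) * d // n_heads)
        q, k, v = Q[:, sl], K[:, sl], V[:, sl]
        w = softmax(q @ k.T / np.sqrt(q.shape[-1]))  # this head's match scores
        outs.append(w @ v)
    return np.concatenate(outs, axis=-1)

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(5, 12)) for _ in range(3))
out = multi_head_attention(Q, K, V, n_heads=3)
print(out.shape)  # (5, 12): same shape, but blended three ways
```

One head might learn to look for pronoun antecedents, another for nearby adjectives, and so on – at no extra cost in output size.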

(4c) Attention and CNNs.

In some ways, attention is a lot like a CNN.  At each position, it does a computation combining the thing at that position with a few other things elsewhere, while still ignoring most of the surrounding stuff as irrelevant.

But unlike a CNN, the “few other things” don’t have to be close by.  (Attention isn’t “local.”)  And the places you look are not pre-defined and fixed in stone (unlike the fixed CNN “windows,” 3×3 or whatever).  They are dynamically computed, and depend on all the inputs.

(5) The Transformer: attention is all you need.

At first, when people used “self-attention,” they typically applied it just once.  Maybe on top of some RNNs, maybe on top of other things.  Attention was viewed as a thing you sort of sprinkled on top of an existing model to improve it.  Not as the core functional unit, like CNN or RNN layers, which you stack together to make a model.

The Transformer is nothing more than an architecture where the core functional unit is attention.  You stack attention layers on top of attention layers, just like you would do with CNN or RNN layers.

In more detail, a single “block” or “layer” of the Transformer does the following:

  • An attention step
  • A step of local computation at each word/position, not using any of the others

Then you just stack these blocks.  The first attention step anoints each word with some extra meaning, derived from other words that might be related to it.  The first local computation step does some processing on this – this could be stuff like “OK, it seems we found two different nouns that could match this pronoun; let’s cook up a query designed to help decide between them.”  Then the next attention step takes the new, better-understood version of each word and reaches out again to all the others, with new, more sophisticated requests for context.  And again, and again.
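Here’s roughly what one such block looks like in numpy, with plenty of real-world details (multiple heads, residual connections, layer norm) stripped out and sizes made up.  For brevity this sketch even reuses one set of weights across blocks; a real Transformer learns separate weights per block.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def block(X, Wq, Wk, Wv, W1, W2):
    """One simplified Transformer block: attention, then local computation."""
    # Step 1: attention -- every position gathers context from every other
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    A = softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V
    # Step 2: local computation at each position, independent of the others
    return np.maximum(A @ W1, 0) @ W2

rng = np.random.default_rng(0)
n_words, d = 6, 16
X = rng.normal(size=(n_words, d))
params = [rng.normal(size=(d, d)) * 0.1 for _ in range(5)]
for _ in range(3):            # stacking blocks: the sequence never changes size
    X = block(X, *params)
print(X.shape)  # still (6, 16) at every layer
```

The shape staying fixed through the loop is the point made just below: one position per word, all the way up the stack.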

Interestingly, the sequence is the same size at every layer.  There’s always one position per word (or “wordpiece” or “byte pair” or whatever – generally these models divide up texts not quite at word boundaries).  But the value stored for each position, which starts out as just the word, becomes a progressively more “understood” or “processed” thing, representing the word in light of more and more sophisticated reverberations from the context.

(5b) Positional encoding; the Transformer and CNNs.

Oh, there’s one more thing.  The input to the Transformer includes, not just the word at each position, but a running counter that just says “this is word #1,” “this is word #2,” etc.

If this wasn’t there, the Transformer wouldn’t be able to see word order (!).  Attention qua attention doesn’t care where things are, just what they are (and “want”).  But since the Transformer’s attention can look at this running counter, it can do things like “this word is looking for words that are close by.”
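A toy illustration of why the counter matters: four identical word vectors are completely indistinguishable to attention until each gets tagged with its position.  (Appending a raw counter is a simplification on my part; the actual paper adds sinusoidal encodings to the word vectors, and other variants learn position embeddings, but the principle is the same.)

```python
import numpy as np

def add_position_counter(word_vectors):
    """Tag each word vector with its position, so attention can
    "see" word order."""
    n = len(word_vectors)
    counter = np.arange(n).reshape(n, 1)   # "this is word #0", "word #1", ...
    return np.concatenate([word_vectors, counter], axis=1)

words = np.zeros((4, 3))                   # four identical "words"...
tagged = add_position_counter(words)
print(tagged)                              # ...now distinguishable by position
```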

The running counter allows the Transformer to learn, in principle, the same fixed local filters a CNN would use.  CNN-like behavior is one limiting case, where the attention step ignores the words and uses only the position counter.  (The other limiting case is attention without a position counter, where only the words are used.)

I’m papering over some technical details, but morally speaking, this means the Transformer model class is a superset of the CNN model class.  Any set of local CNN filters can be represented as a particular attention computation, so the space of Transformer models includes CNNs as a special case.  This suggests that we have enough text data, at least for pre-training (more on that in a moment), that CNNs for text are too restrictive, and a more flexible model is more appropriate.  We don’t need that inductive bias anymore; we can learn that pattern when (and only when) it’s appropriate.

(6) One model to rule them all.

GPT is a Transformer.  GPT-2 is a Transformer.  BERT is a Transformer.  In fact, they’re the same Transformer.

You don’t have many choices when designing a Transformer.  You stack some number of the “blocks” described above.  You choose how many blocks to stack.  You choose how much data to store in your representation of each position (the “hidden size”).  There are some details, but that’s p much it.

  • If you stack up 12 blocks, with hidden size 768, that’s “GPT,” or “GPT-2 small,” or “BERT_BASE.”
  • If you stack up 24 blocks, with hidden size 1024, that’s “GPT-2 medium,” or “BERT_LARGE.”  (Oh, also BERT_LARGE has 16 attention heads, where AFAIK the others all have 12.)
  • Then there’s the “large” GPT-2, with 36 blocks and hidden size 1280.  And the scary full GPT-2, with 48 blocks and hidden size 1600.

None of this involves any extra design above and beyond the original Transformer.  You just crank up the numbers!

The challenging part is in getting lots of good training data, and in finding a good training objective.  OpenAI credits GPT-2’s success to its size and to its special Reddit-based training corpus.

BERT doesn’t have that magic corpus and only gets half as big, but it has a different training objective where instead of predicting the next word from a partial sentence, it predicts “masked out” words from surrounding context.  This leads to improved performance on the things the original GPT was trying to do.  (You can kind of see why.  Predicting the next word is inherently a guessing game: if our words were fully predictable, there’d be no reason to speak at all.  But if you’re filling in missing words with context on both sides, there’s less of a “guess the writer’s intent” element.)

I’m excited to see what happens when someone combines BERT’s objective with GPT-2’s corpus.

[EDIT 7/21/20: wanted to add an explicit note saying that this particular section ignores the encoder vs. decoder distinction.  BERT has an encoder attention-masking structure, GPT-n has a decoder attention-masking structure.  For the purposes of this big-picture intro, this difference is minor enough to be a distraction, but it *is* a difference.]

(6b) Just add water.

What’s especially exciting about BERT is that you, as someone with some specialized NLP task at hand, don’t really need to do anything.  Attention is all you need.  BERT, already trained on its corpus and brimming with linguistic knowledge, is all you need.

All you do is hook up one of BERT’s outputs to your target of interest, and do a little backpropagation.  Not a lot; the Transformer weights already know most of what they need to know.  You do 2 to 4 passes over your data set, with a learning rate between 2e-5 and 5e-5, and a batch size of 16 or 32.  With a reasonably sized data set you can try all those permutations on an ordinary GPU in like a day.
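That whole fine-tuning recipe is small enough to enumerate exhaustively.  Here’s the grid, using the discrete values the BERT paper recommends (treat these as a reasonable default, not gospel; the actual training loop is omitted):

```python
from itertools import product

# The entire BERT fine-tuning "search space" described above.
epochs = [2, 3, 4]
learning_rates = [2e-5, 3e-5, 5e-5]
batch_sizes = [16, 32]

grid = list(product(epochs, learning_rates, batch_sizes))
print(len(grid), "runs to try, e.g.:", grid[0])
# Each run: load pretrained BERT, attach one output layer for your task,
# and backprop for that many passes over your data.
```

Eighteen runs, each cheap – which is why “try all those permutations on an ordinary GPU in like a day” is realistic.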

It’s wild that this works.  And even wilder that the same damn thing is freaking us all out with its writing prowess.  Two years ago we had a menagerie of complicated custom architectures with a zillion hyperparameters to tune.  These days, we have Transformers, which are just stacks of identically shaped attention blocks.  You take a Transformer off the shelf, hook it up, press a button, and go get lunch.  When you get back from lunch, you’ll have something competitive with any customized, complicated, task-specific model laboriously cooked up by a 2016 grad student.
