Can I implement a transformer by just reading the paper?
Mark Henry
my transformer: source code on github, play with a trained
model on replicate
I want to implement a transformer and write a blog post about it to demonstrate competence. The fastest way to
implement a transformer is to copy someone else's implementation, but this does not demonstrate competence. Instead
of reaching a reasonable compromise via learning through imitation, can I brute-force my way through the learning
process while referring to Attention Is All You Need?
For purposes of this project, using others' code for the modules and network architecture described in the paper is Cheating. I've also turned off my GitHub Copilot completions. Not Cheating: using others' code for everything else, including boilerplate like the tokenizer and libraries like torch.
I will set my goal very modestly: a next-token predictor that implements the transformer architecture and performs
better than chance.
Step 1: get a dataset
I will use the popular and small WikiText-2.
Step 2: tokenization and embedding
I am happy to use someone else's library for this. Claude:
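In outline it is something like this sketch (I'm assuming the HuggingFace datasets and transformers libraries here, and bert-base-uncased to match the 768-dimensional BERT embeddings I talk about below):

# Sketch: load WikiText-2 and tokenize with a pretrained BERT tokenizer
from datasets import load_dataset
from transformers import BertTokenizer

dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

sample = dataset["train"][10]["text"]
token_ids = tokenizer.encode(sample)  # list of ints, with [CLS] and [SEP] special tokens added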
Swell. Now for an embedding layer:
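Something like this sketch, with the vocabulary size taken from the tokenizer:

# Sketch: a learned embedding table, one 768-dimensional vector per vocabulary entry
import torch
import torch.nn as nn

d_model = 768
embedding = nn.Embedding(num_embeddings=tokenizer.vocab_size, embedding_dim=d_model)

vectors = embedding(torch.tensor(token_ids))  # (sequence_length, d_model), untrained so far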
The output from this embedding layer is currently garbage. It will be trained later. We have the option here of using
pretrained embeddings but I've arbitrarily decided to train my own. Later we can try it with pretrained embeddings
and see if it makes a difference in our final score. Regardless we now have the means to turn each token into a
vector representation.
Figure 1 in Attention Is All You Need (AIAYN) shows two embeddings, one labeled Input and one labeled Output. I don't understand why there are two inputs to this transformer; in my mind, the input is a list of tokens, and the output is (logits for) a single token (included in Figure 1 as "Output Probabilities"). Why are there two inputs to the transformer in Figure 1, and why is one of them labeled "Outputs (shifted right)"?
The next paragraph attempts to explain this:
This isn't helpful to me. I thought that when a token is generated, it is appended to the context, and the transformer runs again with this new context.
This also is confusing nonsense to me, which means my assumption about how this transformer architecture works must
be incorrect. I think I remember hearing something about encoder-decoder vs decoder-only, so I'm guessing the
transformers I'm familiar with are decoder-only.
A quick conversation with Claude confirms this. AIAYN is showing applications to machine translation, as much
research did at the time; if I want to make a GPT-like model I will want a decoder-only solution.
This alleviates my ignorance somewhat but now I have to make a judgment call. Do I want to strictly implement AIAYN
or do I want to make a GPT-like model? My preference is clear. So I have set an even more complicated challenge for
myself: I'm going to use AIAYN to implement a model whose architecture is not specifically described in AIAYN.
AIAYN mentions that the dimensionality of the embeddings, and of all of the rest of the model, is 512. It's fortunate that this is arbitrary, because I will be matching the dimensionality of the BERT embeddings I'm using, which is 768 (that is, 512 + 256).
Step 3: Positional Encoding
Each token's embedding vector needs to be modified according to its position in the input. From an information-theoretic perspective, it makes sense that combining two pieces of information like this need not destroy either one, and that we can trust the model to learn to interpret both the meaning of the word and its position in the context. But it still feels weird. How can we change the embedding vector without changing the meaning of the token?
As the paper notes, there is an infinite variety of positional encoding strategies. I'll use the one they use.
Claude:
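The sinusoidal encoding from the paper boils down to roughly this sketch (d_model = 768 in my case):

# Sketch of the AIAYN sinusoidal positional encoding:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
import torch

def positional_encoding(seq_len, d_model=768):
    position = torch.arange(seq_len).unsqueeze(1)      # (seq_len, 1)
    two_i = torch.arange(0, d_model, 2)                # the even "dimensions" 2i
    angle = position / (10000 ** (two_i / d_model))    # (seq_len, d_model / 2)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(angle)                     # even elements get sine
    pe[:, 1::2] = torch.cos(angle)                     # odd elements get cosine
    return pe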
Irritatingly, we are overloading the term "dimension" here. "Dimension" usually refers to the size of a vector (here
d_model is 768) but in the context of the positional encoding it also refers to the element index (that is, 2i or
2i+1).
I'm still pondering the information theory here. The positional encoding values lie in the range -1 to 1. The embeddings are random right now because I haven't trained them yet, but they have a magnitude of 25 or 30-ish. So
unless something changes under training, the positional encoding will be just a small perturbation, just a little
flavor on the embeddings. I guess that if two concepts ended up too close to each other in the embedding space, then
that risks confusion when positional encoding is applied. However, this situation would be penalized in training,
and gradient descent would move these concepts farther apart until they were no longer too close. At first I didn't
grok this but the more I think about it the more sense it makes.
Step 4: Attention
One second, I have to watch the 3blue1brown video on attention. It's a
great video. Although I have to say that all the explanations of attention I've seen so far tend to slip past me.
But there's nothing like implementation to force you to understand, so, onward.
AIAYN prescribes h = 8 attention heads with key dimensionality d_k = d_model / h. In their case d_k comes out to 64; in my case 768 / 8 = 96. This seems fine.
Using Claude for this part risks information leaking into the challenge so it is against my self-imposed rules. If
this code seems suddenly unconventional, that's why.
So far so good... as far as I know. There's no way to test this yet. Turning now to the Attention module:
I see from the torch documentation that nn.Linear will be randomly initialized, great. I don't know what the best way
to make a triangular attention mask is in numpy so I did ask Claude for those few lines, while being careful to
avoid "spoilers." I can't test this code yet but I did run random tensors through it in order to fix basic issues
with dimensions and multiplications, and I got pretty far with that.
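For a sense of the shape of things, one causal head boils down to roughly this sketch (simplified, with illustrative names; in my version each head returns a full d_model-sized output so the heads can simply be summed, which is part of the unconventionality mentioned above):

# Sketch of a single causal self-attention head
import math
import torch
import torch.nn as nn

class AttentionHead(nn.Module):
    def __init__(self, d_model=768, d_k=96):
        super().__init__()
        self.d_k = d_k
        self.q_proj = nn.Linear(d_model, d_k)      # randomly initialized by default
        self.k_proj = nn.Linear(d_model, d_k)
        self.v_proj = nn.Linear(d_model, d_model)  # value output kept at d_model so heads can be summed

    def forward(self, x):                          # x: (seq_len, d_model)
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        # triangular (causal) mask: a token may only attend to itself and earlier tokens
        causal_mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(causal_mask, float("-inf"))
        return torch.softmax(scores, dim=-1) @ v   # (seq_len, d_model)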
Step 5: implement feed-forward and linear layers
After gazing into the famous diagram for a bit, I wrote out the following code:
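In essence, the new pieces are the position-wise feed-forward block and layer normalization; the feed-forward part is roughly this sketch (d_ff = 2048, as in the paper):

# Sketch: the position-wise feed-forward block, FFN(x) = max(0, x W1 + b1) W2 + b2
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, d_model=768, d_ff=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)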
Was very pleased to see that torch provides LayerNorm. That saves me some time.
Changes to Transformer were also implied:
This more or less completes the skeleton of the transformer, which I am very pleased about! Now I just have to rig up
the training.
Step 6: train
I had Claude work up a basic training loop but it let slip that I misunderstood the final Linear layer. It's there to
project the embedding back to the vocabulary size. I guess that makes sense :)
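A minimal sketch of the kind of loop involved (names like model and dataloader are assumed, and for simplicity the sketch takes raw logits over the vocabulary rather than the softmaxed probabilities my version emitted at this point):

# Sketch: basic next-token training step; the final Linear projects d_model back to vocab size
import torch
import torch.nn as nn

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

for input_ids, target_ids in dataloader:   # targets are the inputs shifted by one token
    logits = model(input_ids)              # (seq_len, vocab_size)
    loss = loss_fn(logits, target_ids)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()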
It's working!! I trained for a minute on my Macbook as a proof of concept.
Noise, cool. And then I gave it a proper training run.
I don't think that worked. The next step is to go through the code and work through the following sidequests.
Sidequest 1: check my work to see if I flipped an axis somewhere or misinterpreted a diagram.
Sidequest 2: I wrote the transformer in a very straightforward way: it takes in a list of token ids and outputs
softmaxed probabilities. Sidequest 2 is to implement batching.
Sidequest 3: Learn some more pytorch and speed the transformer up through optimizations.
The following changes were implemented:
implemented batching
I didn't realize LayerNorm was stateful (it has learnable scale and bias parameters); I thought it was just a math operation like softmaxing. This means I actually need two LayerNorm objects in the DecoderLayer, oops
I observed a 2x speedup on this GPU:
However, our loss and perplexity are decreasing VERY slowly, if at all—10.3 is basically randomly choosing the
tokens. At this point I flailed around for a day, unsure if architectural issues were making my transformer
untrainable, but unwilling to show the code to Claude and accept spoilers. But then I tried greatly decreasing the
model's size, down to 64 tokens of context, 4 heads of attention, 3 layers, and an embedding size of 256. (I now
realize that if I'm not using the bert embeddings, there's no reason to imitate their embedding dimension of
768.)
It's a classic overfitting pattern! Hallelujah! My transformer is learni—
Nope.
Because the [SEP] token appears in every training example, the model has learned to score free points by going all-in on [SEP] every round. I added code in the training example tokenizer that snips out all special tokens.
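The snipping is essentially a one-liner, roughly (assuming the HuggingFace tokenizer from earlier, applied to a raw text string):

# Sketch: drop BERT special tokens ([CLS], [SEP], [PAD], ...) from each example
special_ids = set(tokenizer.all_special_ids)
token_ids = [t for t in tokenizer.encode(text) if t not in special_ids]
# (equivalently: tokenizer.encode(text, add_special_tokens=False))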
Yes! Unstuck again. We see that the model has learned basic word frequencies after three epochs of training. I put
another couple of quarters into the training machine, this time training a model equal in size to AIAYN.
The model plateaus quickly before overfitting. This is frustrating. We're going to try using pretrained
embeddings.
This turned out materially identical, although it reached its plateau within one epoch.
Claude is more optimistic than I am, saying that coding mistakes are unlikely, and that if it learned down to 600
perplexity then it is likely just fine and only needs a better training regime. We will start with more data.
WikiText-2 is less than 1% of the size of the corpus used for training in AIAYN. So let's try WikiText-103.
Dang it.
Step 7: give in
At this point I am willing to fold and ask Claude to review my transformer code. I will preregister 90% confidence
that there are no structural issues that are causing the learning to stall out—I've done a lot of troubleshooting of
the code over the last few workdays and although I've deliberately omitted a lot of typical structure that would
greatly improve the performance of the transformer, I didn't see any smoking guns.
Claude said:
Missing Attention Head Scaling: The most significant issue is in the DecoderLayer's forward pass. You're summing
the attention heads without scaling them. This means each head's contribution is effectively amplified by the
number of heads, which can lead to extremely large values. This could cause gradient instability during
training. You should divide by the number of heads:
attention_values = torch.stack([head(input) for head in self.attention_heads]).sum(0) / self.num_attention_heads
Positional Encoding Addition: In the Transformer's forward pass, you're adding positional encodings directly.
While this works, the scale of the positional encodings might be too large compared to the embeddings,
especially early in training. Consider adding a learned scaling factor or using layer normalization after adding
positions.
Missing Dropout: While this might be one of your intentional simplifications, the lack of dropout anywhere in
the model could lead to overfitting, which might explain why training stalls at a certain perplexity.
Mask Padding Tokens: You might want to also mask padding tokens if you're using them (I see you have
pad_token_id in the Transformer class but it's not being used for masking).
Missing Pre-Attention Layer Norm: In the current architecture, layer normalization happens after attention. Many
modern transformers use pre-layer normalization (normalizing before attention and FF), which often leads to more
stable training.
Linear Layer Initialization: There's no explicit initialization for the linear layers (Q, K, V, and feed-forward
networks). While PyTorch's default initialization might work, for transformers it's often beneficial to use
specific initialization scales, especially for the output projection.
Cool! Although unfortunately I lose Bayes points because the attention head scaling counts as a major structural
issue that inhibits training. Let's examine these suggestions.
Baseline. Validation set perplexity 529 after 1 epoch
Add attention head scaling. AIAYN doesn't mention this; it is assumed that no one would be so foolish as to sum the attention values together and call it a day. But multiple heads of attention will completely drown out the influence of the residual connection (input + attention_values in my code), and with simple summing the onus is on the attention heads to learn to be very quiet and subtle. So, what results? With attention head scaling, validation set perplexity falls from 410 after 2 epochs to 133 after 4 epochs. That's an amazing improvement! However, its completions are still babble, so let's keep driving that perplexity down.
Positional encoding scaling. I will revisit this if I revisit learned embeddings. But for the moment we are using pretrained embeddings so there's not much reason to worry about positional encodings drowning out the embedding values.
Dropout. Dropout was explicitly described in AIAYN but I believe we should be able to get further without it. I may add this in the future.
Masking padding tokens. This we should definitely get into; a sketch of the combined mask follows below. After implementing this I ran some training again. With attention head scaling and padding token masking: validation set perplexity 580 after 1 epoch, 99 after 7 epochs, and 60 after 1 epoch of WikiText-103. 60 perplexity is progress but not yet enough for complete phrases. Because I'm not overfitting, dropout is not necessary for this project. We're not short on data at all; WikiText-103 is huge.
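Here is a sketch of how the padding mask can be combined with the causal mask (True marks positions that must not be attended to; pad_token_id is the padding id, and the function name is illustrative):

# Sketch: combine the causal mask with a padding mask
import torch

def build_mask(input_ids, pad_token_id):
    seq_len = input_ids.shape[-1]
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    padding = (input_ids == pad_token_id).unsqueeze(-2)  # (..., 1, seq_len)
    return causal | padding                              # broadcasts over the batch

# inside attention: scores = scores.masked_fill(mask, float("-inf"))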
Step 8: change to GPT-2 tokenizer and change to sliding-window dataset strategy
I asked Claude for advice and somehow it convinced me that the GPT-2 tokenizer is "better" for next-token-prediction
tasks. (Evidence for persuasion capabilities??) We also changed the dataset strategy to use a sliding window over
the text, which hews much closer to the task I want to train for, which can only be a good thing. This makes the
dataset colossally larger, 2.4M examples. I started with WikiText-2, but overfitting occurred in a single epoch:
I restarted the training, and early termination led to the lowest loss so far, and the first coherent phrase I've
observed from the model!!
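For reference, the sliding-window dataset amounts to roughly this sketch (context length 128, one-token stride, one long stream of GPT-2 token ids; the class name is illustrative):

# Sketch: sliding-window dataset over a single long token stream
import torch
from torch.utils.data import Dataset

class SlidingWindowDataset(Dataset):
    def __init__(self, token_ids, context_len=128):
        self.ids = torch.tensor(token_ids)
        self.context_len = context_len

    def __len__(self):
        return len(self.ids) - self.context_len

    def __getitem__(self, i):
        return self.ids[i : i + self.context_len]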
High on confidence, I switched to WikiText-103 and let the training run overnight.
Big training run, including discontinuity where I doubled the number of GPUs I was using
After leaving the training running overnight, I realized that
I set up the relationship between the input and the target incorrectly. I was rolling the first token of the input to the back, so that the first token of the input became the last token of the target (see the sketch below). This trained the model to repeat the prompt! Big mistake! It did a great job of repeating the prompt, at least...
Because all examples were full context windows with no padding, the model was uncomfortable with prompts that
begin with a lot of padding (i.e. any prompt shorter than the maximum context length).
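Concretely, the bug versus the fix, on a toy window (illustrative numbers):

# The bug: the target was the input "rolled" by one, so the input's first token
# wrapped around to become the target's last token.
import torch

window = torch.tensor([10, 11, 12, 13])
buggy_target = torch.roll(window, shifts=-1)  # tensor([11, 12, 13, 10])

# The fix: take context_len + 1 consecutive tokens and split them, so every
# target really is the next token in the text.
chunk = torch.tensor([10, 11, 12, 13, 14])
inputs, targets = chunk[:-1], chunk[1:]       # ([10, 11, 12, 13], [11, 12, 13, 14])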
I reworked
the dataset/dataloader to rectify the training issue and introduce short prompts into the mix. I will continue
training from the current model weights because I think that despite the training snafu the model has good knowledge
encoded in it that we can leverage—
Uhhh?? Three????? Yes. Yes! After less than twenty seconds of training, the model perplexity dropped from hundreds of
thousands to 3. My explanation for this is that overnight, the model was largely learning next-token prediction for
tokens 0 through 126, but token 127 was forcefully overridden by the training snafu. Once the snafu was resolved,
the model relaxed back into next-token prediction for all outputs.
It immediately started to show signs of overfitting so I stopped training it after just 100k examples. That's all it
took to untangle this horrible mistake and get an incredible perplexity of 3!
These are FANTASTIC completions.
Brief Q&A
How much did it cost? About $50 I think. I did all my coding on a datacrunch.io server with Intellij IDEA
remotely attached, so the meter was running about 40 cents an hour whenever I was working on this project, plus I
left the meter running overnight accidentally once, plus storage costs, plus more expensive GPUs for big training
runs.
Lessons learned? I thought that "getting the code right" would be difficult. It turned out that training was
where I spent most of my time, even with Claude's assistance. Long feedback loops are an absolute killer. It would
have saved some time to carefully check the inputs and targets my code was generating.
Can you ask it questions? Let's see.
Not really. It's enthusiastic at least but needs a fine tune.
You did this without dropout. What were the consequences? Dropout is great in circumstances where you need to
make the most of limited data. But unlike the authors of AIAYN I don't have to worry about overfitting—I have
WikiText-103 which is, as far as a model of this size is concerned, unlimited data.
Can I see some more completions? Sure.
It's not very talkative. I should downweight the eos token or do a fine tune but I probably won't.
Reflections on methodology? AIAYN is not a technical manual nor a tutorial. It assumes that the reader has a
lot of context from previous papers. I set a challenge for myself but I had to erode the rules over time. For this
reason I don't endorse using AIAYN as a technical manual, obviously. However, I do endorse setting yourself a
silly challenge and then eroding it over time until you get the job done. By making this project difficult
for myself, I learned a ton.
Conclusion
I did the architecture in hard mode and the
training in easy mode. I deliberately
simplified the model's architecture but achieved a perplexity of 3 for next-token prediction of the WikiText
dataset using a model size of just 43M parameters. I'm now equipped to answer interview questions about ML
engineering and undertake more ambitious training projects.