Given the excitement over ChatGPT, I spent part of the winter recess trying to understand the underlying technology of transformers. After trying various tutorials, I found the best explanation in the original 2017 paper, Attention Is All You Need. This is my attempt to figure out positional encoding in transformers, with some naive questions along the way. Let me know if I'm completely off base.

Earlier models like recurrent neural nets and convolutional neural nets worked under the assumption that nearby information was more likely to be correlated. Machine learning seems to improve as models make fewer and fewer assumptions about the data, and transformers use positional data with no built-in favoritism toward nearby inputs.

The attention paper uses the following encoding for positions. Let α(j) = 10000^{-2j/d}, where d is the dimension of the input embedding, i.e., the number of numbers used to represent each word of the input. We encode position p as d/2 pairs of numbers, cos(p α(j)) and sin(p α(j)), for j ranging from 0 to d/2 - 1. The authors chose this function because they hypothesized it would make relative positions easy to attend to: for a fixed offset k, the encoding of position p+k is a linear combination of cos(p α(j)) and sin(p α(j)) whose coefficients depend only on k, by the addition rules for cos and sin.

cos ((p+k) α(j)) = cos(k α(j)) cos(p α(j)) - sin(k α(j)) sin(p α(j))

sin ((p+k) α(j)) = cos(k α(j)) sin(p α(j)) + sin(k α(j)) cos(p α(j))
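To check the identity numerically, here is a minimal NumPy sketch. It follows the paper's layout (sin at even indices, cos at odd ones) with α(j) = 10000^{-2j/d}; the function names `positional_encoding` and `shift` are my own, and d, p, k are arbitrary test values. The point is that `shift` uses only k, never p.

```python
import numpy as np

def alpha(d):
    # Frequencies alpha(j) = 10000^(-2j/d) for j = 0, ..., d/2 - 1
    return 10000.0 ** (-2.0 * np.arange(d // 2) / d)

def positional_encoding(p, d):
    # Paper's layout: index 2j holds sin(p*alpha(j)), index 2j+1 holds cos(p*alpha(j))
    a = alpha(d)
    pe = np.empty(d)
    pe[0::2] = np.sin(p * a)
    pe[1::2] = np.cos(p * a)
    return pe

def shift(pe, k, d):
    # The linear map from PE(p) to PE(p+k), built from the addition rules above.
    # Its coefficients cos(k*alpha(j)) and sin(k*alpha(j)) depend only on k, not on p.
    a = alpha(d)
    s, c = pe[0::2], pe[1::2]
    out = np.empty(d)
    out[0::2] = np.cos(k * a) * s + np.sin(k * a) * c  # sin((p+k) alpha(j))
    out[1::2] = np.cos(k * a) * c - np.sin(k * a) * s  # cos((p+k) alpha(j))
    return out

d, p, k = 16, 37, 5
print(np.allclose(shift(positional_encoding(p, d), k, d), positional_encoding(p + k, d)))  # True
```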

The d-dimensional vector of position encodings is added to the input embedding.
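A sketch of that addition for a whole input sequence, assuming NumPy and random stand-ins for the learned token embeddings (the helper name and the sizes are illustrative):

```python
import numpy as np

def positional_encoding_table(n, d):
    # Row p is the encoding of position p: sin(p*alpha(j)) in column 2j, cos(p*alpha(j)) in column 2j+1
    positions = np.arange(n)[:, None]                   # shape (n, 1)
    alpha = 10000.0 ** (-2.0 * np.arange(d // 2) / d)   # shape (d/2,)
    angles = positions * alpha                          # shape (n, d/2) by broadcasting
    table = np.empty((n, d))
    table[:, 0::2] = np.sin(angles)
    table[:, 1::2] = np.cos(angles)
    return table

rng = np.random.default_rng(0)
n, d = 10, 16
token_embeddings = rng.normal(size=(n, d))               # stand-in for learned embeddings of 10 tokens
x = token_embeddings + positional_encoding_table(n, d)   # the summed input the first layer sees
print(x.shape)  # (10, 16)
```

Note that the sum has the same dimension d as the token embeddings; nothing in the shape records which part came from the position.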

**Why are the position encodings added to the input embedding?**

I scoured the Internet and can't seem to find a good reason for this, other than that it seems to work. Wouldn't the linear combinations that handle relative positions muddle up the input embedding? Since the input embedding is learned, perhaps some parts of the embedding are made zero or very small so the positional encoding stands out. Why not concatenate the two, giving separate inputs for the input embedding and the positions? You wouldn't even need to fully double the dimension, since the positional encoding would no longer need to match the dimension of the input embedding.
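One way to make the muddling concrete is to expand the dot product attention computes between two positions. Ignoring the learned query and key projections (they are linear, so the same expansion holds with them in place), addition produces two cross terms that mix content with position, while concatenation keeps the two contributions separate. A sketch with random stand-in vectors, not real embeddings:

```python
import numpy as np

rng = np.random.default_rng(1)
d = 8
x_p, x_q = rng.normal(size=d), rng.normal(size=d)  # token embeddings at positions p and q (stand-ins)
e_p, e_q = rng.normal(size=d), rng.normal(size=d)  # positional encodings of p and q (stand-ins)

# Addition: two cross terms mix content with position.
added = (x_p + e_p) @ (x_q + e_q)
cross = x_p @ e_q + e_p @ x_q
print(np.isclose(added, x_p @ x_q + e_p @ e_q + cross))  # True

# Concatenation: content and position contribute separately, with no cross terms,
# at the cost of longer vectors.
concat = np.concatenate([x_p, e_p]) @ np.concatenate([x_q, e_q])
print(np.isclose(concat, x_p @ x_q + e_p @ e_q))          # True
```

With concatenation the positional part could also be shorter than the token part, which is the dimension saving suggested above.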

**Why not use complex numbers?**

I see cos(p α(j)) and sin(p α(j)) and immediately think of them as the real and imaginary parts of e^{p α(j) i}. So why not just do the positional encodings as complex numbers? This paper suggests multiplying e^{p α(j) i} with the input embedding, i.e., the input is embedded in the amplitude and the position in the phase. That makes more sense. You can now multiply by e^{k α(j) i} to get the relative position p+k without affecting the input embedding.
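A minimal sketch of the complex version, assuming NumPy, one frequency per complex component, and random positive amplitudes standing in for a learned embedding. The helper `encode` is hypothetical and is not code from the cited paper.

```python
import numpy as np

def alpha(m):
    # One frequency per complex component: alpha(j) = 10000^(-j/m); with m = d/2 this is 10000^(-2j/d)
    return 10000.0 ** (-np.arange(m) / m)

def encode(amplitudes, p):
    # Hypothetical complex encoding: the token lives in the amplitude, position p in the phase
    return amplitudes * np.exp(1j * p * alpha(len(amplitudes)))

rng = np.random.default_rng(2)
m, p, k = 8, 37, 5
amp = rng.uniform(0.1, 1.0, size=m)       # stand-in for a learned, nonnegative embedding

z = encode(amp, p)
shifted = z * np.exp(1j * k * alpha(m))   # multiply componentwise by e^{k alpha(j) i}

print(np.allclose(shifted, encode(amp, p + k)))  # True: now encodes position p + k
print(np.allclose(np.abs(shifted), amp))         # True: the amplitudes are untouched
```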

I don't see a good reason not to use complex numbers for transformers, given that most learning packages and GPUs can handle complex numbers just fine. Even if you don't want to use complex numbers, you could multiply by the sin and cos versions of the positional encoding instead of adding them to achieve a similar effect.
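Without complex arrays, multiplying by e^{p α(j) i} becomes rotating each pair of real coordinates by the angle p α(j). That is one reading of "multiply instead of add," sketched below with a hypothetical `rotate_pairs` helper. It is essentially the idea behind rotary position embeddings (RoPE), which apply such rotations to the queries and keys inside attention rather than to the input embedding.

```python
import numpy as np

def rotate_pairs(x, p):
    # Treat (x[2j], x[2j+1]) as the real and imaginary parts of one complex number and
    # rotate it by p*alpha(j): the real-valued form of multiplying by e^{p alpha(j) i}.
    d = len(x)
    alpha = 10000.0 ** (-2.0 * np.arange(d // 2) / d)
    c, s = np.cos(p * alpha), np.sin(p * alpha)
    re, im = x[0::2], x[1::2]
    out = np.empty(d)
    out[0::2] = c * re - s * im
    out[1::2] = s * re + c * im
    return out

rng = np.random.default_rng(3)
d = 16
x, y = rng.normal(size=d), rng.normal(size=d)   # stand-ins for two token embeddings

# Rotations preserve length, so the embedding's magnitude is left alone.
print(np.isclose(np.linalg.norm(rotate_pairs(x, 37)), np.linalg.norm(x)))  # True
# The dot product of two rotated vectors depends only on the offset between positions.
print(np.isclose(rotate_pairs(x, 37) @ rotate_pairs(y, 30),
                 rotate_pairs(x, 107) @ rotate_pairs(y, 100)))             # True
```

Because rotations compose by adding angles, rotating x to position 37 and y to position 30 gives the same dot product as positions 107 and 100: only the offset of -7 survives.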

**How about positional encoding for outputs?**

Transformers output words in order, but that makes it harder to relate outputs that are far apart. So why not give positional encodings to the outputs? A post-processor could then put the outputs in the correct order. More generally, how about outputting a program that produces the real output? We know transformers can generate code, and this way they could handle operations that transformers normally struggle with, like multiplication or sorting.
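The post-processor itself would be simple. A sketch, assuming the model emits each token tagged with a predicted position; the tagged output below is made up for illustration.

```python
# Hypothetical model output: each token paired with the position the model predicts for it.
predicted = [(2, "sat"), (0, "the"), (3, "down"), (1, "cat")]

def reorder(tagged_tokens):
    # Post-processor: sort by predicted position, then drop the position tags
    return [tok for _, tok in sorted(tagged_tokens)]

print(reorder(predicted))  # ['the', 'cat', 'sat', 'down']
```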

A thought that comes to mind: because this work was done at Google, they were likely working with TPUs, which natively support the bfloat16 data type. bfloat16 uses half the bits of a typical float32, keeping float32's 8-bit exponent but only 7 explicit mantissa bits, so it has far less precision. Perhaps bfloat16 arithmetic required them to be more careful about precision here.
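To see why precision could matter here: bfloat16 has 8 significant bits, so it represents integers exactly only up to 256, and nearby positions beyond that collapse together. NumPy has no native bfloat16 type, so this sketch emulates it with a hypothetical `to_bfloat16` helper that rounds each float32 to its top 16 bits (round to nearest, ties to even).

```python
import numpy as np

def to_bfloat16(x):
    # Emulate bfloat16: keep the top 16 bits of each float32 (sign, 8 exponent bits,
    # 7 mantissa bits), rounding to nearest with ties to even. Assumes finite inputs.
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    rounded = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)

positions = np.arange(250, 262, dtype=np.float32)
print(to_bfloat16(positions))
print(len(np.unique(to_bfloat16(positions))))  # 9: twelve positions, only nine distinct values
```

Whether this mattered in practice depends on where the encodings were computed; building the table in float32 and casting afterward avoids merging the positions themselves, though the sin and cos values would still keep only two to three significant decimal digits.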

My understanding/intuition for why positions are added instead of concatenated is that if they were concatenated, you'd be mandating that the entire network use X dimensions for the token embedding and Y for the position embedding. In reality, it's better to give the network the ability to push that split one way or the other, compressing one of them into fewer dimensions. So as long as it _can_ recover the separate embeddings from the sum, this gives it more flexibility to play with.
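The condition that the network _can_ recover the separate embeddings from the sum holds exactly in an idealized case: if token embeddings and positional signals occupy orthogonal subspaces, linear projections of the kind a learned layer can implement pull them apart. A sketch with an arbitrary 12/4 split of 16 dimensions; the sinusoidal encodings are not confined to a subspace like this, so the split is an illustrative assumption.

```python
import numpy as np

rng = np.random.default_rng(4)
d, d_tok = 16, 12   # hypothetical split: 12 directions for tokens, 4 for positions

# A random orthonormal basis of R^d, split into a "token" subspace and a "position" subspace.
q, _ = np.linalg.qr(rng.normal(size=(d, d)))
U_tok, U_pos = q[:, :d_tok], q[:, d_tok:]

token = U_tok @ rng.normal(size=d_tok)          # token embedding confined to its subspace
position = U_pos @ rng.normal(size=d - d_tok)   # positional signal confined to the other one
x = token + position                            # the summed input

# Orthogonal projections (a linear layer could learn these) separate the two exactly.
print(np.allclose(U_tok @ (U_tok.T @ x), token))      # True
print(np.allclose(U_pos @ (U_pos.T @ x), position))   # True
```

Here the split between token and position directions is a choice training could move, rather than one fixed in advance as with concatenation.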
