The "entanglement" part intuitively makes sense to me, but one bit I always get caught up on is the key, query, and value matrices. In every self-attention explanation I've read/watched, they tend to get thrown out there (similar to what you did here) but their usage/purpose is left a little vague.
Would you mind trying to explain those in more detail? I've heard the database analogy where you start with a query to get a set of keys which you then use to lookup a value, but that doesn't really compute with my mental model of neural networks.
Is it accurate to say that these separate QKV matrices are layers in the network? That doesn't seem exactly right since I think the self-attention layer as a whole contains these three different matrices. I would assume they got their names for a reason that should make it somewhat easy to explain their individual purposes and what they try to represent in the NN.
I'm still trying to get a handle on that part myself... But my ever-evolving understanding goes something like this:
The "Query" matrix is like a mask that is capable of selecting certain kinds of features from the context, while the "Key" matrix focuses the "Query" on specific locations in the context.
Using the Query + Key combination, we select and extract those features from the context matrix. And then we apply the "Value" matrix to those features in order to prepare them for feed-forward into the next layer.
There are multiple "Attention Heads" per layer (GPT-3 had 96 heads per layer), and each head performs its own separate QKV operation. After applying those 96 Q+K->V attention operations per layer, the results are merged back into a single matrix so that they can be fed forward into the next layer.
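If it helps, the whole Q/K/V dance for a single head can be sketched in a few lines of numpy. The dimension sizes and variable names here are purely illustrative, not taken from any real model:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# toy dimensions for illustration only
seq_len, d_model, d_head = 4, 8, 8
rng = np.random.default_rng(0)

X = rng.standard_normal((seq_len, d_model))   # one token embedding per row
W_Q = rng.standard_normal((d_model, d_head))  # three learned projection matrices
W_K = rng.standard_normal((d_model, d_head))
W_V = rng.standard_normal((d_model, d_head))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V           # project the same input three ways
scores = Q @ K.T / np.sqrt(d_head)            # how strongly each token attends to each other token
weights = softmax(scores, axis=-1)            # each row sums to 1
out = weights @ V                             # weighted mix of the value vectors

# `out` has shape (seq_len, d_head); with multiple heads, the per-head
# outputs are concatenated and projected back to d_model before the
# feed-forward part of the layer.
```

So the "selecting" intuition above roughly corresponds to the `Q @ K.T` similarity scores, and the "extract and prepare" step corresponds to mixing the `V` rows by those scores.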
Or something like that...
I'm still trying to grok it myself, and if anyone here can shed more light on the details, I'd be very grateful!
I'm still trying to understand, for example, how many QKV matrices are actually stored in a model with a particular number of parameters. For example, in a GPT-NeoX-20B model (with 20 billion params) how many distinct Q, K, and V matrices are there, and what is their dimensionality?
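One way to at least bound the question: in most transformer implementations, each layer stores one W_Q, one W_K, and one W_V of shape roughly (d_model, d_model), and the individual heads just slice those big matrices. A back-of-the-envelope count, with made-up dimensions (not any real model's config), looks like this:

```python
# rough attention-parameter count (hypothetical dims, not GPT-NeoX-20B's actual config)
d_model = 1024
n_layers = 24

# per layer: one W_Q, one W_K, one W_V, each (d_model x d_model), ignoring biases
qkv_params_per_layer = 3 * d_model * d_model
total_qkv_params = n_layers * qkv_params_per_layer  # 75,497,472 for these toy numbers
```

Under that layout there are 3 * n_layers distinct Q/K/V matrices in the whole model, and the heads are views into them rather than separately stored matrices.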
EDIT:
I just read Imnimo's comment below, and it provides a much better explanation about QKV vectors. I learned a lot!
It's basically the same idea as convolution in image processing. For example, you take the 3-channel RGB value of a single pixel and do some weighted math on it together with the values of the surrounding pixels, which gives you some value(s). Depending on the dimensions of everything, you can end up with a smaller-dimensional output, like a single 3-channel RGB value, or a higher-dimensional output (e.g., for a 5x5 kernel, you can end up with a 9x9 output).
The confusing part that doesn't get mentioned is that the Q, K, and V vectors are not the raw inputs; they are derived from the input with a standard linear transformation y = A*x + b, where x is the (embedding of the) input word, A is the linear layer's weight matrix, and b is the bias. Those weights are the things that are learned through the training process.
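In code, that per-token projection is just a linear layer applied to the embedding. The names and sizes below are illustrative, not from a real model:

```python
import numpy as np

d_model, d_head = 8, 8
rng = np.random.default_rng(1)

x = rng.standard_normal(d_model)              # embedding of one input word
A_q = rng.standard_normal((d_head, d_model))  # learned weights for the query projection
b_q = np.zeros(d_head)                        # learned bias (often omitted in practice)

q = A_q @ x + b_q  # this q (not x itself) is what gets compared against the keys
```

The key and value vectors come from the same x via their own separate (A, b) pairs, which is why the same word can play very different roles as a query, a key, and a value.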