ch03: MultiHeadAttention vs. MultiHeadAttentionWrapper context vectors computation efficiency #559
-
Hi @rasbt, Thanks for the great book and resources! In Chapter 03 (page 90), you mention that the combined MultiHeadAttention class is more computationally efficient than the MultiHeadAttentionWrapper, because it computes the keys, queries, and values for all heads with a single large matrix multiplication each, instead of one smaller multiplication per head.
However, based on my understanding, isn't the key weight matrix in the combined class just as large as all of the per-head key matrices taken together, so that the total amount of computation is the same? To investigate, I measured the computation time for generating the context vectors with both solutions, using the following input dimensions (I'm on an M3 Pro, and the results were similar for both cpu and mps devices):

```python
import torch
import torch.nn as nn

input_1 = torch.rand((10000, 12000))  # one sequence of 10,000 tokens, each embedded in 12,000 dimensions
input_2 = torch.rand((10000, 12000))
batch = torch.stack((input_1, input_2), dim=0)  # stack the two inputs into a batch
batch_size, context_length, d_in = batch.shape  # batch_size = 2, context_length = 10000, d_in = 12000

d_out = 1000
num_heads = 10

# Class defined in ch03.ipynb
mha_sequential = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=num_heads
)

# Class defined in ch03.ipynb
# For comparable results, the second parameter is d_out * num_heads
mha_concatenated = MultiHeadAttention(
    d_in, d_out * num_heads, context_length, 0.0, num_heads=num_heads
)

%timeit mha_sequential(batch)
%timeit mha_concatenated(batch)
```

The results I obtained suggest the opposite of what was stated, with mha_sequential (the MultiHeadAttentionWrapper) running faster than mha_concatenated. Could you help clarify if I made any mistakes? Thanks!
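For readers without the notebook at hand, here is a minimal, self-contained sketch of the core difference being benchmarked. These are not the book's classes from ch03.ipynb, just the key-projection step in isolation with illustrative shapes: the wrapper-style approach applies one small projection per head, one after another, while the combined class applies a single large projection covering all heads.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)

d_in, d_out, num_heads = 768, 64, 12
x = torch.rand(2, 1024, d_in)  # (batch_size, num_tokens, d_in)

# Wrapper-style: one small key projection per head, applied sequentially
per_head_keys = nn.ModuleList(
    [nn.Linear(d_in, d_out, bias=False) for _ in range(num_heads)]
)
keys_sequential = torch.cat([w(x) for w in per_head_keys], dim=-1)

# Combined-style: one large key projection covering all heads at once
fused_key = nn.Linear(d_in, d_out * num_heads, bias=False)
keys_fused = fused_key(x)

# Same output shape either way: (2, 1024, d_out * num_heads)
print(keys_sequential.shape, keys_fused.shape)
```

The FLOP count is identical in both variants; the question raised in this thread is only how efficiently the hardware executes many small matrix multiplications versus one large one.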
-
Hi there,

I am glad you are liking the book, and that's a good question regarding the relative efficiency of the two implementations.

You are right here. But a large matrix multiplication is often more efficient than multiple smaller matrix multiplications that follow each other sequentially. I think in your case that's because your matrices are now too large 😅. I remember testing it with the training code in Chapter 5 and found that the MultiHeadAttentionWrapper was definitely slower. (You can try this experiment by opening the Ch05…)

Btw, if you are curious, I also have some comparisons (but with smaller toy examples) here: https://github.com/rasbt/LLMs-from-scratch/tree/main/ch03/02_bonus_efficient-multihead-attention
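To see the effect described in the reply in isolation, here is a minimal timing sketch (CPU only; on mps or cuda you would additionally need to synchronize the device before reading the clock, and the crossover point depends heavily on matrix sizes and hardware) comparing one large matrix multiplication against the same work split into many small sequential ones:

```python
import time
import torch

torch.manual_seed(123)

n_tokens, d_in, d_out, num_heads = 1024, 768, 64, 12
x = torch.rand(n_tokens, d_in)
small_ws = [torch.rand(d_in, d_out) for _ in range(num_heads)]
big_w = torch.cat(small_ws, dim=1)  # (d_in, d_out * num_heads)

def bench(fn, repeats=50):
    fn()  # warm-up run so one-time setup costs are excluded
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats

t_small = bench(lambda: torch.cat([x @ w for w in small_ws], dim=1))
t_big = bench(lambda: x @ big_w)
print(f"{num_heads} small sequential matmuls: {t_small * 1e3:.3f} ms")
print(f"1 large matmul:             {t_big * 1e3:.3f} ms")
```

With toy shapes like these, the single large matmul usually wins; as the reply notes, with very large matrices the result can flip.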