MQA (Multi-Query Attention)의 논문 이름은 Fast Transformer Decoding: One Write-Head is All You Need다. (링크)
저자는 Noam Shazeer다.
MQA는 여러가지 Head들에 대해서 Keys와 Values를 공유함으로써 메모리에 불러오는 비용을 줄인다.
Abstract
Large sized Keys와 Values를 지속적으로 load함에 있어서 memory-bandwidth 비용이 많이 들고 추론에 있어서 느려진다.
이를 해결하기 위해서 multi-query attention을 제안한다. 서로 다른 heads에 대해서 keys와 values를 sharing한다.
원본 논문에서는 Dot Product Attention, Multi-head Attention, Multi-head Attention (Batched), Multi-Query Attention 식으로 친절하게 단계별로 설명했기 때문에 따라가기 수월한 편이다. 특히 einsum으로 dimension의 변화도 표기해주었다.
파이썬 코드 형식으로 적었는데 주석이나 함수 색깔들은 최대한 원문을 따랐다.
Dot Product Attention
def DotProductAttention ( q , K, V) :
" " " Dot−Product Attention on one query .
Args :
q : a vector with shape [ k ]
K: a ma tr ix with shape [m, k ]
V: a ma tr ix with shape [m, v ]
Returns :
y : a vector with shape [ v ]
"""
logits = tf.einsum ( " k , mk −> m" , q , K)
weights = tf.softmax ( logits )
return tf.einsum ( "m, mv −> v " , weights , V)
Multi-head Attention
$x$는 input vector다.
$h$ differenct attention layers ($h$ heads)
$M$은 $m$개의 different iunput vectors다.
def MultiheadAttention ( x, M, P_q, P_k, P_v, P_o) :
" " " Multi-head Attention on one query .
Args :
x: a vector with shape [d]
M: a matrix with shape [m, d]
P_q: a tensor with shape [h, d, k]
P_k: a tensor with shape [h, d, k]
P_v: a tensor with shape [h, d, v]
P_o: a tensor with shape [h, d, v]
Returns :
y : a vector with shape [ d ]
"""
q = tf.einsum ( " d , hdk −> hk" , x , P_q)
K = tf.einsum ( " md , hdk −> hmk" , M , P_k)
V = tf.einsum ( " md , hdk −> hmv" , M , P_v)
logits = tf.einsum ( " hk , hmk −> hm" , q , K)
weights = tf.softmax ( logits )
o = tf.einsum ( " hm , hmv −> hv" , weights , V)
y = tf.einsum ( "hv, hdv −> d " , o , P_o)
return y
Multi-head Attention (Batched)
$X$는 input vector의 모음인 matrix다.
$n$은 sequence의 positions다.
$b$는 batch size다.
$h$ differenct attention layers ($h$ heads)
$M$은 $m$개의 different iunput vectors다.
def MultiheadAttentionBatched ( X, M, mask, P_q, P_k, P_v, P_o) :
" " " Multi-head Attention .
Args :
X: a tensor with shape [b, n. d]
M: a tensor with shape [b, m, d]
mask: a tensor with shape [b, h, n, m]
P_q: a tensor with shape [h, d, k]
P_k: a tensor with shape [h, d, k]
P_v: a tensor with shape [h, d, v]
P_o: a tensor with shape [h, d, v]
Returns :
Y : a vector with shape [ b, n, d ]
"""
Q = tf.einsum ( " bnd , hdk −> bhnk" , X , P_q)
K = tf.einsum ( " bmd , hdk −> bhmk" , M , P_k)
V = tf.einsum ( " bmd , hdk −> bhmv" , M , P_v)
logits = tf.einsum ( " bhnk , bhmk −> bhnm" , q , K)
weights = tf.softmax ( logits + mask )
O = tf.einsum ( " bhnm , bhmv −> bhnv" , weights , V)
Y = tf.einsum ( "bhnv, hdv −> bnd " , o , P_o)
return Y
Multi-Query Attention
MQA는 기본적으로 batch 단위다.
$X$는 input vector의 모음인 matrix다.
$n$은 sequence의 positions다.
$b$는 batch size다.
$h$ differenct attention layers ($h$ heads)
$M$은 $m$개의 different iunput vectors다.
def MultiqueryAttentionBatched ( X, M, mask, P_q, P_k, P_v, P_o) :
" " " Multi-head Attention .
Args :
X: a tensor with shape [b, n. d]
M: a tensor with shape [b, m, d]
mask: a tensor with shape [b, h, n, m]
P_q: a tensor with shape [h, d, k]
P_k: a tensor with shape [d, k]
P_v: a tensor with shape [d, v]
P_o: a tensor with shape [h, d, v]
Returns :
Y : a vector with shape [ b, n, d ]
"""
Q = tf.einsum ( " bnd , hdk −> bhnk" , X , P_q)
K = tf.einsum ( " bmd , dk −> bmk" , M , P_k)
V = tf.einsum ( " bmd , dk −> bmv" , M , P_v)
logits = tf.einsum ( " bhnk , bmk −> bhnm" , q , K)
weights = tf.softmax ( logits + mask )
O = tf.einsum ( " bhnm , bhmv −> bhnv" , weights , V)
Y = tf.einsum ( "bhnv, hdv −> bnd " , o , P_o)
return Y
보면 알겠지만 원래 [h x d x k] 사이즈였던 P_k 텐서가 [d x k] 텐서로 변했음을 알 수 있다.
P_v 텐서도 비슷하게 [h x d x v] 사이즈였던 텐서가 [d x v] 텐서로 변했음을 알 수 있다.
Multihead Attention Incremental은 병렬처리를 위해 각각 처리한 다음 concat하는 과정이 포함되어 있는데 여기서는 생략한다.
논문의 원문을 보면 쉽게 이해 가능하다.
'NLP' 카테고리의 다른 글
Longformer (2020) 논문 리뷰 (0) | 2025.04.09 |
---|---|
GLU variants (2020) 논문 리뷰 (0) | 2025.04.09 |
GPT 2 (2019) 논문 리뷰 (0) | 2025.04.09 |
Sentence-BERT (2019) 논문 리뷰 (0) | 2025.04.09 |
Instruct learning, fine tuning, and T5 (0) | 2025.01.28 |