본문 바로가기
NLP

MQA (Multi-Query Attention) (2019) 논문 리뷰

by 아르카눔 2025. 4. 9.

MQA (Multi-Query Attention)의 논문 이름은 Fast Transformer Decoding: One Write-Head is All You Need다. (링크)

 

저자는 Noam Shazeer다. 

 

MQA는 여러가지 Head들에 대해서 Keys와 Values를 공유함으로써 메모리에 불러오는 비용을 줄인다.

 

 

Abstract

Large sized Keys와 Values를 지속적으로 load함에 있어서 memory-bandwidth 비용이 많이 들고 추론에 있어서 느려진다.

이를 해결하기 위해서 multi-query attention을 제안한다. 서로 다른 heads에 대해서 keys와 values를 sharing한다. 

 

원본 논문에서는 Dot Product Attention, Multi-head Attention, Multi-head Attention (Batched), Multi-Query Attention 식으로 친절하게 단계별로 설명했기 때문에 따라가기 수월한 편이다. 특히 einsum으로 dimension의 변화도 표기해주었다.

파이썬 코드 형식으로 적었는데 주석이나 함수 색깔들은 최대한 원문을 따랐다.

 

 

Dot Product Attention

 

def DotProductAttention ( q , K, V) :
" " " Dot−Product Attention on one query .
Args :
  q : a vector with shape [ k ] 

  K: a ma tr ix with shape [m, k ] 

  V: a ma tr ix with shape [m, v ]

Returns :
  y : a vector with shape [ v ] 

""" 

logits = tf.einsum ( " k , mk −> m" , q , K)

weights = tf.softmax ( logits )

return tf.einsum ( "m, mv −> v " , weights , V)

 

Multi-head Attention

$x$는 input vector다.  

$h$ differenct attention layers ($h$ heads)

$M$은 $m$개의 different iunput vectors다.  

 

def MultiheadAttention ( x, M, P_q, P_k, P_v, P_o) :
" " " Multi-head Attention on one query .
Args :

  x: a vector with shape [d]

  M: a matrix with shape [m, d]

  P_q: a tensor with shape [h, d,  k] 

  P_k: a tensor with shape [h, d,  k] 

  P_v: a tensor with shape [h, d,  v] 

  P_o: a tensor with shape [h, d,  v] 

Returns :
  y : a vector with shape [ d ] 

""" 

q = tf.einsum ( " d , hdk −> hk" , x , P_q)

K = tf.einsum ( " md , hdk −> hmk" , M , P_k)

V = tf.einsum ( " md , hdk −> hmv" , M , P_v)

logits = tf.einsum ( " hk , hmk −> hm" , q , K)

weights = tf.softmax ( logits )

o = tf.einsum ( " hm , hmv −> hv" , weights , V)

y = tf.einsum ( "hv, hdv −> d " , o , P_o)

return y

 

Multi-head Attention (Batched)

$X$는 input vector의 모음인 matrix다.  

$n$은 sequence의 positions다.  

$b$는 batch size다.  

$h$ differenct attention layers ($h$ heads)

$M$은 $m$개의 different iunput vectors다.  

 

def MultiheadAttentionBatched ( X, M, mask, P_q, P_k, P_v, P_o) :
" " " Multi-head Attention .
Args :

  X: a tensor with shape [b, n. d]

  M: a tensor with shape [b, m, d]

  mask: a tensor with shape [b, h, n, m]

  P_q: a tensor with shape [h, d,  k] 

  P_k: a tensor with shape [h, d,  k] 

  P_v: a tensor with shape [h, d,  v] 

  P_o: a tensor with shape [h, d,  v] 

Returns :
  Y : a vector with shape [ b, n, d ] 

""" 

Q = tf.einsum ( " bnd , hdk −> bhnk" , X , P_q)

K = tf.einsum ( " bmd , hdk −> bhmk" , M , P_k)

V = tf.einsum ( " bmd , hdk −> bhmv" , M , P_v)

logits = tf.einsum ( " bhnk , bhmk −> bhnm" , q , K)

weights = tf.softmax ( logits + mask )

O = tf.einsum ( " bhnm , bhmv −> bhnv" , weights , V)

Y = tf.einsum ( "bhnv, hdv −> bnd " , o , P_o)

return Y

 

 

Multi-Query Attention

MQA는 기본적으로 batch 단위다. 

 

$X$는 input vector의 모음인 matrix다.  

$n$은 sequence의 positions다.  

$b$는 batch size다.  

$h$ differenct attention layers ($h$ heads)

$M$은 $m$개의 different iunput vectors다.  

 

def MultiqueryAttentionBatched ( X, M, mask, P_q, P_k, P_v, P_o) :
" " " Multi-head Attention .
Args :

  X: a tensor with shape [b, n. d]

  M: a tensor with shape [b, m, d]

  mask: a tensor with shape [b, h, n, m]

  P_q: a tensor with shape [h, d,  k] 

  P_k: a tensor with shape [d,  k] 

  P_v: a tensor with shape [d,  v] 

  P_o: a tensor with shape [h, d,  v] 

Returns :
  Y : a vector with shape [ b, n, d ] 

""" 

Q = tf.einsum ( " bnd , hdk −> bhnk" , X , P_q)

K = tf.einsum ( " bmd , dk −> bmk" , M , P_k)

V = tf.einsum ( " bmd , dk −> bmv" , M , P_v)

logits = tf.einsum ( " bhnk , bmk −> bhnm" , q , K)

weights = tf.softmax ( logits + mask )

O = tf.einsum ( " bhnm , bhmv −> bhnv" , weights , V)

Y = tf.einsum ( "bhnv, hdv −> bnd " , o , P_o)

return Y

 

보면 알겠지만 원래 [h x d x k] 사이즈였던 P_k 텐서가 [d x k] 텐서로 변했음을 알 수 있다.

P_v 텐서도 비슷하게 [h x d x v] 사이즈였던 텐서가 [d x v] 텐서로 변했음을 알 수 있다.

 

 

Multihead Attention Incremental은 병렬처리를 위해 각각 처리한 다음 concat하는 과정이 포함되어 있는데 여기서는 생략한다.

논문의 원문을 보면 쉽게 이해 가능하다.

 

 

 

 

'NLP' 카테고리의 다른 글

Longformer (2020) 논문 리뷰  (0) 2025.04.09
GLU variants (2020) 논문 리뷰  (0) 2025.04.09
GPT 2 (2019) 논문 리뷰  (0) 2025.04.09
Sentence-BERT (2019) 논문 리뷰  (0) 2025.04.09
Instruct learning, fine tuning, and T5  (0) 2025.01.28