AI Codes/PyTorch

PyTorch Products

아르카눔 2025. 4. 20. 15:26

파이토치에서 지원하는 element-wise product (Hadamrad product), dot product, matrix multiplication 등등을 알아본다.

 

구체적인 명령어들은 아래와 같다. 

 

  • *
  • torch.dot
  • torch.matmul
  • torch.mm
  • @
  • torch.bmm
  • torch.vdot
  • torch.outer
  • torch.tensordot

 

 

Element-wise Product = Hadamard Product

연산자들

  • *

 

 

1D Tensor 1차원 텐서

# element-wise product
# hadamard product

a = torch.Tensor([1, 2, 3])
b = torch.Tensor([0, 2, 4])
a * b

>> tensor([ 0.,  4., 12.])

 

2D 이상의 텐서 

 

# element-wise product
# hadamard product

a = torch.rand((2, 3))
b = torch.randint(0, 10, (2, 3))
print(f'Shape of A: {a.shape}, B: {b.shape}')

a * b

>> Shape of A: torch.Size([2, 3]), B: torch.Size([2, 3])

tensor([[1.1109, 0.2328, 0.0090],
        [0.0000, 6.6022, 0.1797]])

 

Dot Product

연산자들

  • torch.dot

 

1D 텐서, 벡터 형태만 지원한다. 

# dot product
# 1D tensor만 지원한다. 

a = torch.rand((3,))
b = torch.randint(0, 10, (3,))
b = b.float()
c = torch.dot(a, b)
c

>> tensor(2.5526)

 

 

Matrix Multiplication, Matrix Product,  행렬 곱

torch.matmul, torch.mm, @을 비교한다.

 

연산자들

  • torch.matmul
  • torch.mm
  • @

 

# matrix product
a = torch.rand((2, 3))
b = torch.rand((2, 3))
print(f'Shape of A: {a.shape}, B: {b.shape}, transpoed B: {b.T.shape}')
print('')

# matrix product
# NOT support boradcasting
c = torch.mm(a, b.T)
print(c)
print('')

# matrix product
# Support boradcasting
c = torch.matmul(a, b.T)
print(c)
print('')

# matrix product
# Support boradcasting
d = a @ b.T
print(d)

>>
Shape of A: torch.Size([2, 3]), B: torch.Size([2, 3]), transpoed B: torch.Size([3, 2])

tensor([[0.5299, 0.2206],
        [1.0264, 0.5442]])

tensor([[0.5299, 0.2206],
        [1.0264, 0.5442]])

tensor([[0.5299, 0.2206],
        [1.0264, 0.5442]])

 

torch.matmul은 broadcasting 브로드캐스팅을 지원한다.

 

torch.mm은 broadcasting 브로드캐스팅을 지원하지 않는다.

 

@는 torch.matmul과 같아서 broadcasting 브로드캐스팅을 지원한다.

 

matmul은 텐서의 차원을 고려하지 않아도 되어서 편해보이겠지만,

 

실제로는 디버깅이 어려울 수 있으므로 애초에 mm을 쓰는게 나을수도 있다.

 

 

Tensor Product 텐서 곱

3차원 이상의 텐서에 대한 곱 계산이다.

 

연산자들

  • torch.bmm
# 3d tensor product
# (N x M x K) @ (N x K x P) = (N x M x P)

a = torch.rand((2, 3, 4))
b = torch.rand((2, 4, 5))
print(f'Shape of A: {a.shape}, B: {b.shape}, transpoed B: {b.T.shape}')

# NOT support boradcasting
c = torch.bmm(a, b)
print(c)
print(f'c.shape:: {c.shape}')

>> Shape of A: torch.Size([2, 3, 4]), B: torch.Size([2, 4, 5]), transpoed B: torch.Size([5, 4, 2])
tensor([[[1.0647, 0.7326, 0.4270, 0.4807, 0.9519],
         [1.6801, 0.9897, 0.9374, 0.8012, 1.3327],
         [2.1697, 1.4453, 1.0139, 1.1711, 2.3518]],

        [[0.7599, 1.6414, 0.8666, 1.1576, 0.7004],
         [0.9261, 2.1580, 1.2677, 1.4448, 1.0381],
         [0.6031, 0.9930, 0.5035, 0.8552, 0.4479]]])
c.shape:: torch.Size([2, 3, 5])

 

 

위 예시의 경우 N개의 배치에 대한 (M x K) 와 (K x P)의 행렬곱이라고 볼 수 있다.

 

# (1, 3, 4)

a = torch.tensor([[
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12]
]]) 

 

# (1, 4, 2)
b = torch.tensor([[
    [1, 2],
    [3, 4],
    [5, 6],
    [7, 8]
]])  

 

a[0][1] = [5, 6, 7, 8]
b[0][:,1] = [2, 4, 6, 8]

result[0][1][1] = 5*2 + 6*4 + 7*6 + 8*8
                = 10 + 24 + 42 + 64 = 140

 

이런식으로 모든 조합에 대해서 계산한다. 

 

 

Inner Product 내적

연산자들

  • torch.inner

 

# torch inner product 
# 1D 이상의 텐서에도 적용 가능 
import torch

a = torch.tensor([[1, 2], 
                  [3, 4]])
b = torch.tensor([[5, 6], 
                  [7, 8]])

c = torch.inner(a, b)

'''
c[0][0] = a[0] * b[0]
c[0][1] = a[0] * b[1]
c[1][0] = a[1] * b[0]
c[1][1] = a[1] * b[1]

17 = 1*5 + 2*6, 23 = 1*7 + 2*8
39 = 3*5 + 4*6, 53 = 3*7 + 4*8

'''

print(c)
>> tensor([[17, 23],
        [39, 53]])

 

 

1D 이상의 텐서에 대해서 가능한 계산이다.

아래는 2D 텐서에서의 내적 계산의 구체적인 예시다.

 

c[0][0] = a[0] * b[0]
c[0][1] = a[0] * b[1]
c[1][0] = a[1] * b[0]
c[1][1] = a[1] * b[1]

17 = 1*5 + 2*6

23 = 1*7 + 2*8
39 = 3*5 + 4*6

53 = 3*7 + 4*8

 

 

Outer Product 외적

연산자들

  • torch.outer

 

# outer product 외적 

v1 = torch.arange(1., 5.)
v2 = torch.arange(1., 4.)

torch.outer(v1, v2)

>> tensor([[ 1.,  2.,  3.],
        [ 2.,  4.,  6.],
        [ 3.,  6.,  9.],
        [ 4.,  8., 12.]])

 

vdot

Complex number 복소수에 대한 dot product 연산이다.

 

연산자들

  • torch.vdot
# torch vdot product
# complex number에 대한 dot product 

torch.vdot(torch.tensor([2, 3]), torch.tensor([2, 1]))
a = torch.tensor((1 +2j, 3 - 1j))
b = torch.tensor((2 +1j, 4 - 0j))

print(torch.vdot(a, b))
print(torch.vdot(b, a))

>>
tensor(16.+1.j)
tensor(16.-1.j)

 

 

 

 

 

 

 

 

References:

https://sikmulation.tistory.com/90

https://velog.io/@regista/torch.dot-torch.matmul-torch.mm-torch.bmm

https://pytorch.org/docs/stable/generated/torch.dot.html

https://pytorch.org/docs/stable/generated/torch.matmul.html

https://pytorch.org/docs/stable/generated/torch.mm.html

https://pytorch.org/docs/stable/generated/torch.tensordot.html