PyTorch Products
파이토치에서 지원하는 element-wise product (Hadamrad product), dot product, matrix multiplication 등등을 알아본다.
구체적인 명령어들은 아래와 같다.
- *
- torch.dot
- torch.matmul
- torch.mm
- @
- torch.bmm
- torch.vdot
- torch.outer
- torch.tensordot
Element-wise Product = Hadamard Product
연산자들
- *
1D Tensor 1차원 텐서
# element-wise product
# hadamard product
a = torch.Tensor([1, 2, 3])
b = torch.Tensor([0, 2, 4])
a * b
>> tensor([ 0., 4., 12.])
2D 이상의 텐서
# element-wise product
# hadamard product
a = torch.rand((2, 3))
b = torch.randint(0, 10, (2, 3))
print(f'Shape of A: {a.shape}, B: {b.shape}')
a * b
>> Shape of A: torch.Size([2, 3]), B: torch.Size([2, 3])
tensor([[1.1109, 0.2328, 0.0090],
[0.0000, 6.6022, 0.1797]])
Dot Product
연산자들
- torch.dot
1D 텐서, 벡터 형태만 지원한다.
# dot product
# 1D tensor만 지원한다.
a = torch.rand((3,))
b = torch.randint(0, 10, (3,))
b = b.float()
c = torch.dot(a, b)
c
>> tensor(2.5526)
Matrix Multiplication, Matrix Product, 행렬 곱
torch.matmul, torch.mm, @을 비교한다.
연산자들
- torch.matmul
- torch.mm
- @
# matrix product
a = torch.rand((2, 3))
b = torch.rand((2, 3))
print(f'Shape of A: {a.shape}, B: {b.shape}, transpoed B: {b.T.shape}')
print('')
# matrix product
# NOT support boradcasting
c = torch.mm(a, b.T)
print(c)
print('')
# matrix product
# Support boradcasting
c = torch.matmul(a, b.T)
print(c)
print('')
# matrix product
# Support boradcasting
d = a @ b.T
print(d)
>>
Shape of A: torch.Size([2, 3]), B: torch.Size([2, 3]), transpoed B: torch.Size([3, 2])
tensor([[0.5299, 0.2206],
[1.0264, 0.5442]])
tensor([[0.5299, 0.2206],
[1.0264, 0.5442]])
tensor([[0.5299, 0.2206],
[1.0264, 0.5442]])
torch.matmul은 broadcasting 브로드캐스팅을 지원한다.
torch.mm은 broadcasting 브로드캐스팅을 지원하지 않는다.
@는 torch.matmul과 같아서 broadcasting 브로드캐스팅을 지원한다.
matmul은 텐서의 차원을 고려하지 않아도 되어서 편해보이겠지만,
실제로는 디버깅이 어려울 수 있으므로 애초에 mm을 쓰는게 나을수도 있다.
Tensor Product 텐서 곱
3차원 이상의 텐서에 대한 곱 계산이다.
연산자들
- torch.bmm
# 3d tensor product
# (N x M x K) @ (N x K x P) = (N x M x P)
a = torch.rand((2, 3, 4))
b = torch.rand((2, 4, 5))
print(f'Shape of A: {a.shape}, B: {b.shape}, transpoed B: {b.T.shape}')
# NOT support boradcasting
c = torch.bmm(a, b)
print(c)
print(f'c.shape:: {c.shape}')
>> Shape of A: torch.Size([2, 3, 4]), B: torch.Size([2, 4, 5]), transpoed B: torch.Size([5, 4, 2])
tensor([[[1.0647, 0.7326, 0.4270, 0.4807, 0.9519],
[1.6801, 0.9897, 0.9374, 0.8012, 1.3327],
[2.1697, 1.4453, 1.0139, 1.1711, 2.3518]],
[[0.7599, 1.6414, 0.8666, 1.1576, 0.7004],
[0.9261, 2.1580, 1.2677, 1.4448, 1.0381],
[0.6031, 0.9930, 0.5035, 0.8552, 0.4479]]])
c.shape:: torch.Size([2, 3, 5])
위 예시의 경우 N개의 배치에 대한 (M x K) 와 (K x P)의 행렬곱이라고 볼 수 있다.
# (1, 3, 4)
a = torch.tensor([[
[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]
]])
# (1, 4, 2)
b = torch.tensor([[
[1, 2],
[3, 4],
[5, 6],
[7, 8]
]])
a[0][1] = [5, 6, 7, 8]
b[0][:,1] = [2, 4, 6, 8]
result[0][1][1] = 5*2 + 6*4 + 7*6 + 8*8
= 10 + 24 + 42 + 64 = 140
이런식으로 모든 조합에 대해서 계산한다.
Inner Product 내적
연산자들
- torch.inner
# torch inner product
# 1D 이상의 텐서에도 적용 가능
import torch
a = torch.tensor([[1, 2],
[3, 4]])
b = torch.tensor([[5, 6],
[7, 8]])
c = torch.inner(a, b)
'''
c[0][0] = a[0] * b[0]
c[0][1] = a[0] * b[1]
c[1][0] = a[1] * b[0]
c[1][1] = a[1] * b[1]
17 = 1*5 + 2*6, 23 = 1*7 + 2*8
39 = 3*5 + 4*6, 53 = 3*7 + 4*8
'''
print(c)
>> tensor([[17, 23],
[39, 53]])
1D 이상의 텐서에 대해서 가능한 계산이다.
아래는 2D 텐서에서의 내적 계산의 구체적인 예시다.
c[0][0] = a[0] * b[0]
c[0][1] = a[0] * b[1]
c[1][0] = a[1] * b[0]
c[1][1] = a[1] * b[1]
17 = 1*5 + 2*6
23 = 1*7 + 2*8
39 = 3*5 + 4*6
53 = 3*7 + 4*8
Outer Product 외적
연산자들
- torch.outer
# outer product 외적
v1 = torch.arange(1., 5.)
v2 = torch.arange(1., 4.)
torch.outer(v1, v2)
>> tensor([[ 1., 2., 3.],
[ 2., 4., 6.],
[ 3., 6., 9.],
[ 4., 8., 12.]])
vdot
Complex number 복소수에 대한 dot product 연산이다.
연산자들
- torch.vdot
# torch vdot product
# complex number에 대한 dot product
torch.vdot(torch.tensor([2, 3]), torch.tensor([2, 1]))
a = torch.tensor((1 +2j, 3 - 1j))
b = torch.tensor((2 +1j, 4 - 0j))
print(torch.vdot(a, b))
print(torch.vdot(b, a))
>>
tensor(16.+1.j)
tensor(16.-1.j)
References:
https://sikmulation.tistory.com/90
https://velog.io/@regista/torch.dot-torch.matmul-torch.mm-torch.bmm
https://pytorch.org/docs/stable/generated/torch.dot.html
https://pytorch.org/docs/stable/generated/torch.matmul.html
https://pytorch.org/docs/stable/generated/torch.mm.html
https://pytorch.org/docs/stable/generated/torch.tensordot.html