PyTorch의 repeat와 repeat_interleave와 expand의 주요 차이점을 설명한다.
1. torch.repeat:
1차원 텐서
x = torch.tensor([1, 2, 3])
result = x.repeat(3)
>> tensor([1, 2, 3, 1, 2, 3, 1, 2, 3])
2차원 텐서
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
x.repeat(2, 1)
>> tensor([[1, 2, 3],
[4, 5, 6],
[1, 2, 3],
[4, 5, 6]])
이때 repeat 안에 들어가는 숫자는 복사하고자 하는 텐서의 차원이 같아야 한다.
2차원 텐서에 대해서 아래와 같이 입력하면 에러가 발생한다.
x.repeat(2)
>> Cell In[4], line 1
----> 1 x.repeat(2)
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
- 전체 텐서를 그대로 반복한다
- 패턴 전체가 순차적으로 반복한다
2. torch.repeat_interleave:
1차원 텐서
x = torch.tensor([1, 2, 3])
torch.repeat_interleave(x, 3)
>> tensor([1, 1, 1, 2, 2, 2, 3, 3, 3])
2차원 텐서
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
torch.repeat_interleave(x, 3)
>> tensor([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6])
모든 원소를 3번씩 반복하되 이를 flatten한다. 기본적인 repeat_interleave의 설정이다.
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
torch.repeat_interleave(x, 3, dim = 0)
>> tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
이때 dim = 0의 옵션을 주면 0번째 dim의 요소 전체를 하나의 원소로 보고 3번씩 반복한다.
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
torch.repeat_interleave(x, 3, dim = 1)
>> tensor([[1, 1, 1, 2, 2, 2, 3, 3, 3],
[4, 4, 4, 5, 5, 5, 6, 6, 6]])
이때 dim = 1의 옵션을 주면 0번째 dim 내에서 1번째 dim인 column 단위로 요소 원소들을 3번씩 반복한다.
- 데이터를 원소 단위로 복사된다
- 원소 단위로 패턴이 복사된다
3. torch.expand:
torch.expand는 1인 dim에 대해서만 가능하다.
그러니까 dim이 [3, 1, 2]일 때 1에 대해서만 적용이 가능하다.
1차원 텐서
x = torch.tensor([1, 2, 3])
x.expand(-1, 3)
>> RuntimeError: The expanded size of the tensor (-1) isn't allowed in a leading, non-existing dimension 0
1차원 텐서의 경우 위의 코드처럼 에러가 발생함을 알 수 있다.
2차원 텐서
x = torch.tensor([1, 2, 3])
x = x.unsqueeze(0)
print(x.shape)
>> torch.Size([1, 3])
x.expand(2, -1)
>> tensor([[1, 2, 3],
[1, 2, 3]])
반면에 x를 unsqueeze해서 새로운 dimension을 0번째에 삽입하고 이를 토대로 확장을 하면 위 결과처럼 복제가 되는걸 알 수 있다.
expand에서 -1을 해당하는 디멘션의 크기를 그대로 유지하겠다는 뜻이다.
여기서는 [d1, d2]에서 d2를 3으로 유지하겠다는 뜻이다.
그리고 pytorch의 공식 문서를 보면 expand는 새롭게 메모리를 할당하지 않는다고 한다.
현재 존재하는 텐서에 대한 view만을 새로 생성한다고 한다.
References:
https://chickencat-jjanga.tistory.com/175
https://seducinghyeok.tistory.com/9
https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html
'AI Codes > PyTorch' 카테고리의 다른 글
PyTorch Tensor의 차원 변환 (0) | 2025.04.16 |
---|---|
ML와 DL에서의 seed 고정 (0) | 2025.01.02 |
ResNet (2016) PyTorch Implementation (0) | 2024.04.24 |
SPPNet(2014) PyTorch Implementation (0) | 2024.04.12 |
R-CNN (2014) PyTorch Implementation (0) | 2024.04.08 |