본문 바로가기
AI Codes/PyTorch

PyTorch repeat, repeat_interleave, expand 차이

by 아르카눔 2024. 11. 1.

PyTorch의 repeat와 repeat_interleave와 expand의 주요 차이점을 설명한다.

1. torch.repeat:

1차원  텐서 

x = torch.tensor([1, 2, 3])
result = x.repeat(3)

>> tensor([1, 2, 3, 1, 2, 3, 1, 2, 3])

 

 

2차원  텐서 

 

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

x.repeat(2, 1)

>> tensor([[1, 2, 3],
        [4, 5, 6],
        [1, 2, 3],
        [4, 5, 6]])

 

이때 repeat 안에 들어가는 숫자는 복사하고자 하는 텐서의 차원이 같아야 한다.

 

2차원 텐서에 대해서 아래와 같이 입력하면 에러가 발생한다.

x.repeat(2)

>> Cell In[4], line 1
----> 1 x.repeat(2)

RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor

 

 

- 전체 텐서를 그대로 반복한다
- 패턴 전체가 순차적으로 반복한다 

 


2. torch.repeat_interleave:

 

1차원 텐서

 

x = torch.tensor([1, 2, 3])

torch.repeat_interleave(x, 3)
>> tensor([1, 1, 1, 2, 2, 2, 3, 3, 3])

 

 

2차원 텐서

 

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

torch.repeat_interleave(x, 3)

>> tensor([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6])

 

모든 원소를 3번씩 반복하되 이를 flatten한다. 기본적인 repeat_interleave의 설정이다. 

 

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

torch.repeat_interleave(x, 3, dim = 0)

>> tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [4, 5, 6],
        [4, 5, 6],
        [4, 5, 6]])

 

이때 dim = 0의 옵션을 주면 0번째 dim의 요소 전체를 하나의 원소로 보고 3번씩 반복한다. 

 

 

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

torch.repeat_interleave(x, 3, dim = 1)

>> tensor([[1, 1, 1, 2, 2, 2, 3, 3, 3],
        [4, 4, 4, 5, 5, 5, 6, 6, 6]])

 

이때 dim = 1의 옵션을 주면 0번째 dim 내에서 1번째 dim인 column 단위로 요소 원소들을 3번씩 반복한다. 

 

- 데이터를 원소 단위로 복사된다 

- 원소 단위로 패턴이 복사된다

 

 

 

3. torch.expand:

 

torch.expand는 1인 dim에 대해서만 가능하다.

그러니까 dim이 [3, 1, 2]일 때 1에 대해서만 적용이 가능하다.


1차원  텐서 

 

x = torch.tensor([1, 2, 3])

x.expand(-1, 3)
>> RuntimeError: The expanded size of the tensor (-1) isn't allowed in a leading, non-existing dimension 0

 

 

1차원 텐서의 경우 위의 코드처럼 에러가 발생함을 알 수 있다.

 

 

2차원 텐서 

x = torch.tensor([1, 2, 3])
x = x.unsqueeze(0)
print(x.shape)

>> torch.Size([1, 3])

x.expand(2, -1)

>> tensor([[1, 2, 3],
        [1, 2, 3]])

 

반면에 x를 unsqueeze해서 새로운 dimension을 0번째에 삽입하고 이를 토대로 확장을 하면 위 결과처럼 복제가 되는걸 알 수 있다.

 

expand에서 -1을 해당하는 디멘션의 크기를 그대로 유지하겠다는 뜻이다.

 

여기서는 [d1, d2]에서 d2를 3으로 유지하겠다는 뜻이다.

 

그리고 pytorch의 공식 문서를 보면 expand는 새롭게 메모리를 할당하지 않는다고 한다. 

 

현재 존재하는 텐서에 대한 view만을 새로 생성한다고 한다. 

 

 

 

References:

https://cccaaa.tistory.com/33

https://chickencat-jjanga.tistory.com/175

https://seducinghyeok.tistory.com/9

https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html

 

 

'AI Codes > PyTorch' 카테고리의 다른 글

PyTorch Tensor의 차원 변환  (0) 2025.04.16
ML와 DL에서의 seed 고정  (0) 2025.01.02
ResNet (2016) PyTorch Implementation  (0) 2024.04.24
SPPNet(2014) PyTorch Implementation  (0) 2024.04.12
R-CNN (2014) PyTorch Implementation  (0) 2024.04.08