본문 바로가기
Computer Vision

SPPNet(2014) PyTorch Implementation

by 아르카눔 2024. 4. 12.

SPPNet을 Pytorch를 활용하여 구현하고자 한다.

https://arsetstudium.tistory.com/35에서 공부한 내용을 토대로 구현보면 아래와 같다.

 

SPPNet은 R-CNN처럼 Spatial Pyramid Pooling을 제외하고 CNN 구조 자체는 기존과 동일하므로 이는 생략한다.

 

SPP는 max pool 또는 average pool이며 중요한 사항은 바로 window와 stride의 사이즈다.

 

 

개별 SPP를 우선 구현한다.

 

import torch
import torch.nn as nn
import math

class PyramidPoolCell(nn.Module):
    # a is the size of feature map
    # n is the bin size of pyramid pooling
    def __init__(self, a, n, mode='average'):
        super().__init__()
        # Calculate window and stride size by bin and feature size
        self.window = math.ceil(a/n)
        self.stride = math.floor(a/n)
        # Average Pool
        if mode == 'average':
            self.model = nn.AvgPool2d(self.window, self.stride)
        # Max Pool
        else:
            self.model = nn.MaxPool2d(self.window, self.stride)
        #print(f"For {n} pyramid")
        #print("window:",self.window, 'stride:',self.stride)
    
    def forward(self, x):
        return self.model(x)

 

논문에 내용을 토대로 Pyramid Pool Layer에서의 개별 bin을 cell이라 칭하고 이를 구현한다.

print 함수로 window와 stride 사이즈가 맞게 계산되었는지 확인한다.

 

 

class PyramidPoolLayer(nn.Module):
    def __init__(self, features, n_list, mode='average'):
        super().__init__()
        # features shape is B x C x W x H
        # a is the size of feature map
        a = min(features.shape[2], features.shape[3])
        # Prepare multiple pyramid poolings 
        pyramid_list = {}
        for n in n_list:
            pp = PyramidPoolCell(a, n, mode)
            pyramid_list[str(n)] = pp
        # setsttr과 get attr로 각각의 pyramid pool cell을
        # 필요할 때 마다 불러와서 계산한다
        setattr(self, 'pyramidpool', nn.ModuleDict(pyramid_list))
        self.n_list = n_list

    def forward(self, x):
        for idx, n in enumerate(self.n_list):
            if idx < 1:
                y = getattr(getattr(self, 'pyramidpool'), str(n))(x)
                # Flatten으로 fixed-length representation을 준비한다
                y = torch.flatten(y, start_dim = 1)
                #print(y.shape)
            else:
                y_new = getattr(getattr(self, 'pyramidpool'), str(n))(x)
                # Flatten으로 fixed-length representation을 준비한다
                y_new = torch.flatten(y_new, start_dim = 1)
                #print(y_new.shape)
                # Concat featrues from pyramid pooling
                y = torch.concat([y, y_new], dim=1)
        return y

 

setattr과 getattr 함수를 이용하여 미리 구현된 pyramid pooling cells를 그때 그때 불러와서 구한다음,

flatten하고 concat해서 fixed-length representation을 구현한다.

 

 

 

pyramid = PyramidPoolLayer(input, [1, 2, 3], 'max')
>> 
For 1 pyramid
window: 13 stride: 13
For 2 pyramid
window: 7 stride: 6
For 3 pyramid
window: 5 stride: 4

 

PyramidPoolLayer를 선언할 때 PyramidPoolCell의 print 함수를 살려서 window와 stride를 확인한다.

논문의 Figure 4와 같은 결과임을 확인할 수 있다.

 

# features shape is B x C x W x H
input = torch.randn(10, 256, 13, 13)

result = pyramid(input)
>>
torch.Size([10, 256])
torch.Size([10, 1024])
torch.Size([10, 2304])

result.shape
torch.Size([10, 3584])

 

논문에서 제시한 13x13x256의 features를 10개의 batch size로 설정하여 PyramidPoolLayer에 집어넣는다.

그 결과는 10 x 3584가 나온다.

 

이제 뒤따라 나올 FC layer의 in_features의 size를 3584로 구현하면 된다.

 

그리고 추가적으로 알아볼 내용은 이전에 https://arsetstudium.tistory.com/26에서 AlexNet pre-trained 구조를 살펴볼 때 나왔던 AdaptiveMaxPool2d와 AdaptiveAvgPool2d다. 

 

SPPNet은 SPP layer를 도입하되 output size를 고려하여 bins의 개수와 사이즈를 설정해야 했다.

아니면 bins의 개수와 사이즈에 맞게 FC layer의 in_features 수를 설정해야 했다.

 

하지만 AdaptivePool은 outptu size 자체를 맞추기 때문에 이미지 사이즈의 문제를 신경쓰지 않아도 된다.

가령, 224x224 크기의 이미지를 넣은 AlexNet의 feature output의 최종적으로 6x6x256이라고 해보자.

이미지 사이즈 256x256이나 512x512라고 하더라도, features 항목 뒤에 AdaptiveAvgPool2d(6)을 넣어주면 알아서 6x6x256의 features를 만들게 되고 이는 원래 AlexNet에서 설정한 FC layer의 in_features와 동일해 진다. 

 

따라서, 단순히 이미지 사이즈가 원래의 모델과 다르기 때문에 생기는 문제의 해결을 위해서하면 adaptive pooling을 쓰면 되고, 이미지의 사이즈 뿐이 아니라 다양한 receptive fields의 features를 concat하고 싶다면 pyramid pooling을 사용하면 된다.

 

 

 

 

References:

https://discuss.pytorch.org/t/use-of-getattr-in-forward-function/46755/3

https://discuss.pytorch.org/t/elegant-implementation-of-spatial-pyramid-pooling-layer/831/3