
GoogLeNet = Inception v1 (2014) PyTorch Implementation

by 아르카눔 2024. 4. 3.

This post implements GoogLeNet (Inception v1) in PyTorch.

The implementation below is based on what I studied in https://arsetstudium.tistory.com/31.

 

 

Inception Block

 

Before building the full GoogLeNet model, I first implement the inception block it is made of.

To keep the first version simple, I hard-code concrete numbers rather than taking the sizes as arguments.

The target is the very first inception block, inception 3a.

See Table 1 below for the detailed parameter settings.

 

 

 

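The input to inception 3a is 28 x 28 x 192, where 192 is the number of channels.

The number of filters (output channels) is 64 for the 1x1 conv, 96 for the 3x3 reduce (the 1x1 conv placed before the 3x3 conv), 128 for the 3x3 conv, 16 for the 5x5 reduce (the 1x1 conv placed before the 5x5 conv), 32 for the 5x5 conv, and 32 for the max pool path (pool proj).

The four paths (1x1, 3x3, 5x5, max pool) are then concatenated along the channel (depth) dimension and passed to the next layer.

The final output should therefore be 28 x 28 x 256, since 64 + 128 + 32 + 32 = 256.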

 

# 3a inception module example
# size of channel in previous layer is 192
import torch
import torch.nn as nn

class Inception(nn.Module):
    def __init__(self, in_features = 192):
        super().__init__()
        # 1 x 1 conv
        self.one = nn.Sequential(
            # 1 x 1 conv
            nn.Conv2d(in_features, 64, kernel_size = 1, padding = 0),
            nn.ReLU()  
        )
        # 3 x 3 conv
        self.three = nn.Sequential(
            # 1 x 1 conv
            nn.Conv2d(in_features, 96, kernel_size = 1, padding = 0),
            nn.ReLU(),
            # 3 x 3 conv
            # Add padding to match image size
            nn.Conv2d(96, 128, kernel_size = 3, padding = 1),
            nn.ReLU()         
        )
        # 5 x 5 conv
        self.five = nn.Sequential(
            # 1 x 1 conv
            nn.Conv2d(in_features, 16, kernel_size = 1, padding = 0),
            nn.ReLU(),
            # 5 x 5 conv
            # Add padding to match image size
            nn.Conv2d(16, 32, kernel_size = 5, padding = 2),
            nn.ReLU()          
        )
        # max pooling
        self.max = nn.Sequential(
            # 3 x 3 max pool
            # Add padding to match image size
            nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1), 
            # 1 x 1 conv
            nn.Conv2d(in_features, 32, kernel_size = 1, padding = 0),
            nn.ReLU(),          
        )

    def forward(self, x):
        # Depth Concat by the dimension of channel
        y = torch.concat((self.one(x), self.three(x), self.five(x), self.max(x)), dim=1)
        return y

 

Each of the four paths (1x1, 3x3, 5x5, max pool) is wrapped in its own nn.Sequential.

To concatenate the four paths along the channel (depth) dimension, the other three dimensions of the b x c x h x w tensors, namely b, h and w, must all match.

b is the batch size, which no layer changes, but a convolution or pooling with a kernel larger than 1x1 shrinks h and w.

To prevent this, padding is applied to the 3x3 conv, the 5x5 conv and the max pool.

The 1x1 convs leave h and w unchanged, so they need no padding.

Since the concatenation is along the channel dimension, it must act on c, the second dimension of b x c x h x w.

That is why dim=1 is passed to torch.concat.
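The padding values can be checked with the usual output-size formula for convolution and pooling layers, out = floor((in + 2 * padding - kernel) / stride) + 1. The sketch below is plain Python; `out_size` is a hypothetical helper name, not a PyTorch function, and it only reproduces the arithmetic for the 28 x 28 input of inception 3a.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv/pool layer (floor rounding, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28  # h and w of the inception 3a input

# Without padding, the larger kernels shrink the feature map.
print(out_size(size, kernel=3))  # 26
print(out_size(size, kernel=5))  # 24

# With stride 1 and padding = (kernel - 1) // 2, the size is preserved.
for kernel in (1, 3, 5):
    padding = (kernel - 1) // 2
    print(kernel, padding, out_size(size, kernel, padding=padding))
```

The loop gives padding 0, 1 and 2 for the 1x1, 3x3 and 5x5 kernels, which are the values used in the Inception class, and each path keeps the 28 x 28 size.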

 

# Input shape of inception 3a
a = torch.rand(5, 192, 28, 28)
a.shape
>> torch.Size([5, 192, 28, 28])

inception = Inception()
inception

>>
Inception(
  (one): Sequential(
    (0): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
  )
  (three): Sequential(
    (0): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
  )
  (five): Sequential(
    (0): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): ReLU()
  )
  (max): Sequential(
    (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
    (1): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
    (2): ReLU()
  )
)

print("one:",inception.one(a).shape)
print("three:",inception.three(a).shape)
print("five:",inception.five(a).shape)
print("max pool:",inception.max(a).shape)
print("Depth concated:",inception(a).shape)

>>
one: torch.Size([5, 64, 28, 28])
three: torch.Size([5, 128, 28, 28])
five: torch.Size([5, 32, 28, 28])
max pool: torch.Size([5, 32, 28, 28])
Depth concated: torch.Size([5, 256, 28, 28])

 

After the depth concat the output is 5 x 256 x 28 x 28 (b x c x h x w), which matches the 28 x 28 x 256 (h x w x c) in the table, so the block is written correctly. To reuse the Inception class for every inception block, the sizes are now turned into arguments.

 

 

# Generalized arguments

class Inception(nn.Module):
    def __init__(self, in_features,
                 out_one,
                 mid_three,
                 out_three,
                 mid_five,
                 out_five,
                 out_max):
        super().__init__()

        self.one = nn.Sequential(
            # 1 x 1 conv
            nn.Conv2d(in_features, out_one, kernel_size = 1, padding = 0),
            nn.ReLU()  
        )

        self.three = nn.Sequential(
            # 1 x 1 conv
            nn.Conv2d(in_features, mid_three, kernel_size = 1, padding = 0),
            nn.ReLU(),
            # 3 x 3 conv
            nn.Conv2d(mid_three, out_three, kernel_size = 3, padding = 1),
            nn.ReLU()         
        )

        self.five = nn.Sequential(
            # 1 x 1 conv
            nn.Conv2d(in_features, mid_five, kernel_size = 1, padding = 0),
            nn.ReLU(),
            # 5 x 5 conv
            nn.Conv2d(mid_five, out_five, kernel_size = 5, padding = 2),
            nn.ReLU()          
        )

        self.max = nn.Sequential(
            # 3 x 3 max pool
            nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1), 
            # 1 x 1 conv
            nn.Conv2d(in_features, out_max, kernel_size = 1, padding = 0),
            nn.ReLU(),          
        )

    def forward(self, x):
        # Depth Concat
        y = torch.concat((self.one(x), self.three(x), self.five(x), self.max(x)), dim=1)
        return y

 

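The arguments are the out_features of the 1x1 path, the middle (the n_features linking the 1x1 reduce to the 3x3 conv) and out_features of the 3x3 path, the middle and out_features of the 5x5 path, and the out_features of the 1x1 conv that follows the max pool.

This time inception 4e is used to check the input and output sizes.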

 

# Input shape of inception 4e
a = torch.rand(5, 528, 14, 14)

inception = Inception(528, 256, 160, 320, 32, 128, 128)

print("one:",inception.one(a).shape)
print("three:",inception.three(a).shape)
print("five:",inception.five(a).shape)
print("max pool:",inception.max(a).shape)

print("Depth concated:",inception(a).shape)

>>
one: torch.Size([5, 256, 14, 14])
three: torch.Size([5, 320, 14, 14])
five: torch.Size([5, 128, 14, 14])
max pool: torch.Size([5, 128, 14, 14])
Depth concated: torch.Size([5, 832, 14, 14])

 

The output of inception 4d, which is the input of 4e, is 14 x 14 x 528, and the output of 4e is 14 x 14 x 832.

The code returns 5 x 832 x 14 x 14, so the result is correct.
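The same bookkeeping can be done for every inception block without running the network: the output channels of a block are the sum of its four path outputs, and they must equal the in_features of the next block (the max pooling between stages does not change the channel count). The sketch below is plain Python. The tuples follow the argument order of the Inception class above, the values are the ones passed to the Inception blocks in the GoogLeNet class of this post, and the names `configs` and `out_channels` are only illustrative.

```python
# (in_features, out_one, mid_three, out_three, mid_five, out_five, out_max)
configs = {
    "3a": (192, 64, 96, 128, 16, 32, 32),
    "3b": (256, 128, 128, 192, 32, 96, 64),
    "4a": (480, 192, 96, 208, 16, 48, 64),
    "4b": (512, 160, 112, 224, 24, 64, 64),
    "4c": (512, 128, 128, 256, 24, 64, 64),
    "4d": (512, 112, 144, 288, 32, 64, 64),
    "4e": (528, 256, 160, 320, 32, 128, 128),
    "5a": (832, 256, 160, 320, 32, 128, 128),
    "5b": (832, 384, 192, 384, 48, 128, 128),
}

def out_channels(cfg):
    """Channels after depth concat: sum of the four path outputs."""
    _, out_one, _, out_three, _, out_five, out_max = cfg
    return out_one + out_three + out_five + out_max

# Each block's output must feed the next block's in_features.
names = list(configs)
for prev, nxt in zip(names, names[1:]):
    assert out_channels(configs[prev]) == configs[nxt][0], (prev, nxt)

for name in names:
    print(name, configs[name][0], "->", out_channels(configs[name]))
```

The last block, 5b, ends at 1024 channels, which is the in_features of the final Linear layer.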

The full GoogLeNet is built on top of this block.

 

class GoogLeNet(nn.Module):
    def __init__(self, num_classes = 1000):
        super().__init__()

        # Declare and initialize the Inception layers
        self.inception_3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception_3b = Inception(256, 128, 128, 192, 32, 96, 64)

        self.inception_4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception_4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception_4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception_4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception_4e = Inception(528, 256, 160, 320, 32, 128, 128)

        self.inception_5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception_5b = Inception(832, 384, 192, 384, 48, 128, 128)
        
        # First convolution
        self.conv_0 = nn.Sequential(
            # Feature extraction part
            # First conv: 224 x 224 -> 112 x 112
            nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3),
            nn.ReLU(),
            # No padding and floor rounding: 112 x 112 -> 55 x 55
            nn.MaxPool2d(kernel_size = 3, stride = 2),
            # BatchNorm2d is used here where the paper uses LocalRespNorm
            nn.BatchNorm2d(64),
        )
        # Second convolution
        self.conv_1 = nn.Sequential(
            # Feature extraction part
            # Second conv
            # padding = 1 on this 1 x 1 conv grows the map: 55 x 55 -> 57 x 57
            nn.Conv2d(64, 64, kernel_size = 1, padding = 1),
            nn.ReLU(),
            nn.Conv2d(64, 192, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.BatchNorm2d(192),
            # 57 x 57 -> 28 x 28
            nn.MaxPool2d(kernel_size = 3, stride = 2),
        )
        # The main branch is the model structure used for the final prediction
        self.main_branch_0 = nn.Sequential(
            # Feature extraction part
            self.inception_3a,
            self.inception_3b,
            nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1),
            self.inception_4a
        )
        # The side branch is used to compute the auxiliary loss
        # Returns softmax0, an auxiliary softmax
        self.side_branch_0 = nn.Sequential(
            nn.AvgPool2d(kernel_size = 5, stride = 3),
            nn.Conv2d(512, 128, kernel_size = 1, stride = 1),
            nn.ReLU()
        )

        self.side_classifier_0 = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(p = 0.7),
            nn.Linear(1024, num_classes),
            nn.Softmax(dim=1)  # normalize over classes (dim 1), not over the batch (dim 0)
        )

 
        self.main_branch_1 = nn.Sequential(
            self.inception_4b,
            self.inception_4c,
            self.inception_4d
        )         

        # The side branch is used to compute the auxiliary loss
        # Returns softmax1, an auxiliary softmax
        self.side_branch_1 = nn.Sequential(
            nn.AvgPool2d(kernel_size = 5, stride = 3),
            nn.Conv2d(528, 128, kernel_size = 1, stride = 1),
            nn.ReLU()
        )

        self.side_classifier_1 = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(p = 0.7),
            nn.Linear(1024, num_classes),
            nn.Softmax(dim=1)  # normalize over classes (dim 1), not over the batch (dim 0)
        )        

        self.main_branch_2 = nn.Sequential(
            self.inception_4e,
            nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1),
            self.inception_5a,
            self.inception_5b,
            nn.AvgPool2d(kernel_size = 7)
        )   

        # Main classifier
        self.classifier = nn.Sequential(
            # Classification Part
            nn.Dropout(p=0.4),
            nn.Linear(1024, num_classes),
            nn.Softmax(dim=1)  # normalize over classes (dim 1), not over the batch (dim 0)
        )

    def forward(self, x):
        y = self.conv_0(x)
        y = self.conv_1(y)
        y = self.main_branch_0(y)

        # auxiliary side branch 0
        y_aux0 = self.side_branch_0(y)
        y_aux0 = torch.flatten(y_aux0, start_dim = 1)
        y_aux0 = self.side_classifier_0(y_aux0)
        
        y = self.main_branch_1(y)

        # auxiliary side branch 1
        y_aux1 = self.side_branch_1(y)
        y_aux1 = torch.flatten(y_aux1, start_dim = 1)
        y_aux1 = self.side_classifier_1(y_aux1)
       
        # main branch 
        y = self.main_branch_2(y)
        y = torch.flatten(y, start_dim = 1)
        y = self.classifier(y)
        return y, y_aux0, y_aux1
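
The forward pass returns three softmax outputs. In the paper, the two auxiliary losses are added to the main loss with a weight of 0.3 during training, and the auxiliary classifiers are discarded at inference. A minimal plain-Python sketch of that combination for a single sample follows. The function names and the toy probabilities are made up for illustration, and a real training loop would operate on tensors instead.

```python
import math

AUX_WEIGHT = 0.3  # discount weight for the auxiliary losses, as stated in the paper

def nll(probs, target):
    """Negative log-likelihood of the target class given softmax probabilities."""
    return -math.log(probs[target])

def total_loss(p_main, p_aux0, p_aux1, target):
    """Main loss plus the two down-weighted auxiliary losses."""
    aux = nll(p_aux0, target) + nll(p_aux1, target)
    return nll(p_main, target) + AUX_WEIGHT * aux

# Toy 3-class probabilities for one sample whose true class is 2.
p_main = [0.1, 0.2, 0.7]
p_aux0 = [0.3, 0.3, 0.4]
p_aux1 = [0.2, 0.3, 0.5]
print(total_loss(p_main, p_aux0, p_aux1, target=2))
```

Because the model already ends in Softmax layers, the sketch takes the log of the returned probabilities. If the Softmax layers were removed, nn.CrossEntropyLoss could be applied to the raw outputs instead, since it expects unnormalized scores.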

 

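I verified the code above layer by layer, comparing each output size with Table 1, as shown below.

The output of the first convolution stage differs in h x w: 55 x 55 here versus 56 x 56 in Table 1. This is most likely not an error in the paper. nn.MaxPool2d with kernel_size = 3, stride = 2 and no padding rounds the size down, whereas rounding up (ceil_mode=True, as in the torchvision model printed later in this post) gives 56. conv_1 still ends at 28 x 28 because the padding = 1 on its 1 x 1 conv enlarges the map to 57 x 57 before the last max pool.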
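
The 55 versus 56 difference can be reproduced with the output-size formula alone. The sketch below is plain Python; `conv_out` and `pool_out` are hypothetical helpers, not PyTorch functions. `pool_out` mimics floor rounding (the nn.MaxPool2d default) and ceil rounding (ceil_mode=True). The extra adjustment for a last window that would start in the padding follows my understanding of PyTorch's rule and should be read as an assumption; it is not triggered by any of the sizes used here.

```python
import math

def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a conv layer (floor rounding, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel, stride, padding=0, ceil_mode=False):
    """Output size of a pooling layer with floor or ceil rounding."""
    span = size + 2 * padding - kernel
    if not ceil_mode:
        return span // stride + 1
    out = math.ceil(span / stride) + 1
    if (out - 1) * stride >= size + padding:  # assumed PyTorch edge-case rule
        out -= 1
    return out

x = conv_out(224, kernel=7, stride=2, padding=3)
print(x)                                       # 112

floor_pool = pool_out(x, 3, 2)                 # as in conv_0 of this post
ceil_pool = pool_out(x, 3, 2, ceil_mode=True)  # as in torchvision's maxpool1
print(floor_pool, ceil_pool)                   # 55 56

# conv_1 of this post: 1x1 conv with padding 1, 3x3 conv with padding 1, floor pool
y = conv_out(conv_out(floor_pool, 1, padding=1), 3, padding=1)
print(y, pool_out(y, 3, 2))                    # 57 28

# Same stage without the extra padding, starting from 56 and pooling in ceil mode
z = conv_out(conv_out(ceil_pool, 1), 3, padding=1)
print(z, pool_out(z, 3, 2, ceil_mode=True))    # 56 28
```

Both routes reach the 28 x 28 x 192 input that inception 3a expects; only the intermediate sizes differ.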

 

 

 

model = GoogLeNet()

a = torch.rand(5, 3, 224, 224)
y_step0 = model.conv_0(a)
y_step0.shape
>> torch.Size([5, 64, 55, 55])

y_step0 = model.conv_0(a)
y_step0 = model.conv_1(y_step0)
y_step0.shape
>> torch.Size([5, 192, 28, 28])

y_step0 = model.conv_0(a)
y_step0 = model.conv_1(y_step0)
y_step0 = model.main_branch_0(y_step0)
y_step0.shape
>> torch.Size([5, 512, 14, 14])

# check auxiliary softmax 0
y_side0 = model.side_branch_0(y_step0)
y_aux0 = torch.flatten(y_side0, start_dim = 1)
y_side0 = model.side_classifier_0(y_aux0)
y_side0.shape
>> torch.Size([5, 1000])

# check softmax0, 1, and 2
y_main, y_aux_0, y_aux_1 = model(a)
y_main.shape, y_aux_0.shape, y_aux_1.shape
>> (torch.Size([5, 1000]), torch.Size([5, 1000]), torch.Size([5, 1000]))

 

In the end, all three softmax outputs return a tensor of size b x 1000, so the output shapes are as intended.
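
A matching shape does not show which axis the softmax normalizes. For a b x num_classes output, each row (one sample) should sum to 1, which corresponds to dim=1 in nn.Softmax; with dim=0 each column would sum to 1 across the batch instead. The plain-Python sketch below imitates both choices on a toy 2 x 3 score matrix; `softmax` and `transpose` are illustrative helpers, not PyTorch functions.

```python
import math

def softmax(values):
    """Softmax of a flat list of scores."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def transpose(matrix):
    """Swap rows and columns of a nested list."""
    return [list(col) for col in zip(*matrix)]

scores = [[1.0, 2.0, 3.0],   # sample 0
          [1.0, 5.0, 1.0]]   # sample 1

# Normalizing each row corresponds to dim=1 (over classes).
over_classes = [softmax(row) for row in scores]
# Normalizing each column corresponds to dim=0 (over the batch).
over_batch = transpose([softmax(col) for col in transpose(scores)])

print([round(sum(row), 6) for row in over_classes])  # [1.0, 1.0]
print([round(sum(row), 6) for row in over_batch])    # rows no longer sum to 1
```

Only the row-wise version yields a probability distribution over classes for each sample.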

 

 

GoogLeNet from torchvision

 

This section compares my implementation with the GoogLeNet code provided by torchvision.

 

 

import torchvision.models as models
googlenet = models.googlenet(weights=None)  # torchvision < 0.13: pretrained=False
googlenet

>>
GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch3): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True)
      (1): BasicConv2d(
        (conv): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (inception3b): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch3): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(32, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True)
      (1): BasicConv2d(
        (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (maxpool3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception4a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(480, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(480, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(96, 208, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(208, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch3): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(480, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(16, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True)
      (1): BasicConv2d(
        (conv): Conv2d(480, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (inception4b): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(512, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(112, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(112, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(224, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch3): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(24, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True)
      (1): BasicConv2d(
        (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (inception4c): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch3): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(512, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(24, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True)
      (1): BasicConv2d(
        (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (inception4d): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(512, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(112, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(512, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(144, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(144, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(288, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch3): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True)
      (1): BasicConv2d(
        (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (inception4e): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(528, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(528, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch3): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(528, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True)
      (1): BasicConv2d(
        (conv): Conv2d(528, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (maxpool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception5a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(832, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(832, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(160, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(320, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch3): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(832, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True)
      (1): BasicConv2d(
        (conv): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (inception5b): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(832, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(832, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(384, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch3): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(832, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicConv2d(
        (conv): Conv2d(48, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (branch4): Sequential(
      (0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=True)
      (1): BasicConv2d(
        (conv): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (aux1): InceptionAux(
    (conv): BasicConv2d(
      (conv): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fc1): Linear(in_features=2048, out_features=1024, bias=True)
    (fc2): Linear(in_features=1024, out_features=1000, bias=True)
    (dropout): Dropout(p=0.7, inplace=False)
  )
  (aux2): InceptionAux(
    (conv): BasicConv2d(
      (conv): Conv2d(528, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(128, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (fc1): Linear(in_features=2048, out_features=1024, bias=True)
    (fc2): Linear(in_features=1024, out_features=1000, bias=True)
    (dropout): Dropout(p=0.7, inplace=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (dropout): Dropout(p=0.2, inplace=False)
  (fc): Linear(in_features=1024, out_features=1000, bias=True)
)

 

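Several differences stand out.

First, every Conv2d has bias=False, which is common when a BatchNorm2d with its own shift term follows the conv.

Second, BatchNorm2d uses eps=0.001 instead of the default 1e-05. The other options, momentum and affine, equal their defaults of 0.1 and True in PyTorch 2.0.1; they appear only because the module printout always lists them.

Third, the average pooling is an AdaptiveAvgPool2d rather than an AvgPool2d. With output_size=(1, 1) it always produces a 1 x 1 map per channel, so the network still fits the GoogLeNet structure when the input image size differs.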
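
With output_size=(1, 1), adaptive average pooling reduces to the mean over all spatial positions of each channel, so the feature passed to the final Linear layer has one value per channel whatever the spatial size. A plain-Python sketch of that reduction for a single channel, with an illustrative helper name, is:

```python
def global_avg_pool(feature_map):
    """Mean over an h x w feature map (nested lists), i.e. a 1 x 1 output."""
    values = [v for row in feature_map for v in row]
    return sum(values) / len(values)

map_7x7 = [[1.0] * 7 for _ in range(7)]  # size after inception 5b for a 224 x 224 input
map_9x9 = [[1.0] * 9 for _ in range(9)]  # a hypothetical larger map

print(global_avg_pool(map_7x7), global_avg_pool(map_9x9))  # 1.0 1.0
```

A fixed nn.AvgPool2d(kernel_size = 7), by contrast, produces more than one output position once the map reaches 14 x 14, so the flattened feature would no longer have 1024 values.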

 

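Inside each inception block, the paths are named branch1 through branch4 instead of being split the way I did.

The printout also shows no ReLU, and batch normalization is applied after every conv in every branch.

However, ReLU is in fact applied in the code and simply is not printed, as explained in the PyTorch discussion listed in the references below.

Batch normalization was introduced later, in Inception v2 and v3, and appears to have been adopted in torchvision.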

 

 

 

References:

https://velog.io/@euisuk-chung/%ED%8C%8C%EC%9D%B4%ED%86%A0%EC%B9%98-%ED%8C%8C%EC%9D%B4%ED%86%A0%EC%B9%98%EB%A1%9C-CNN-%EB%AA%A8%EB%8D%B8%EC%9D%84-%EA%B5%AC%ED%98%84%ED%95%B4%EB%B3%B4%EC%9E%90-GoogleNet%ED%8E%B8

https://datascience.stackexchange.com/questions/67064/what-doese-v-mean-in-googlenet

https://discuss.pytorch.org/t/absent-relu-layers-in-pretrained-googlenet/169530

https://github.com/pytorch/vision/blob/32d254bbfcf14975f846765775584e61ef25a5bc/torchvision/models/googlenet.py#L184-L275

https://sh-tsang.medium.com/review-batch-normalization-inception-v2-bn-inception-the-2nd-to-surpass-human-level-18e2d0f56651

https://hyunsooworld.tistory.com/entry/Inception-v1v2v3v4%EB%8A%94-%EB%AC%B4%EC%97%87%EC%9D%B4-%EB%8B%A4%EB%A5%B8%EA%B0%80-CNN%EC%9D%98-%EC%97%AD%EC%82%AC

 

 

 

 

 
