close

下圖為VGG系列的結構表

VGG的結構其實是由AlexNet演變而來,

VGG原文參考:

https://arxiv.org/abs/1409.1556

 

AlexNet是由5層Convolution Layer + 3層Fully-connected Layer而組成,

AlexNet結構設計可參考如下:

https://jennaweng0621.pixnet.net/blog/post/403588460

 

下面為VGG系列結構的第一種寫法:

import torch.nn as nn

cfg = { #M表示為MaxPool2d, 其他數值表示為各Convolution輸出的深度

       'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],

       'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],

       'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],

       'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M',512, 512, 512, 512, 'M']

       }

 

class VGG(nn.Module):

    def __init__(self, vgg_name):

        super(VGG, self).__init__()

        self.features = self._make_layers(cfg[vgg_name])

        self.classifier = nn.Linear(512, 10)

 

    def forward(self, x):

        out = self.features(x)

        out = out.view(out.size(0), -1)

        out = self.classifier(out)

        return out

 

    def _make_layers(self, cfg):

        layers = []

        in_channels = 3

        for x in cfg:

            if x == 'M':

                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]

            else:

                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),

                           nn.BatchNorm2d(x),

                           nn.ReLU(inplace=True)] #預設為False,表示新建一個對象對其修改, True則表示直接對這個對象進行修改

                in_channels = x

        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]

        return nn.Sequential(*layers)

 

#取得VGG系列的model

vgg11 = VGG('VGG11')

vgg13 = VGG('VGG13')

vgg16 = VGG('VGG16')

vgg19 = VGG('VGG19')

 

下面為VGG11輸出結果:

 

下面為VGG系列結構的第二種寫法(官方預設結構):

import torchvision.models as models

vgg11 = models.vgg11()

vgg13 = models.vgg13()

vgg16 = models.vgg16()

vgg19 = models.vgg19()

可以導入預訓練的model權重, 只要在()內填入pretrained參數即可, 如下:

vgg11 = models.vgg11(pretrained = True)

下面為VGG11輸出結果:

arrow
arrow
    創作者介紹
    創作者 楓綺 的頭像
    楓綺

    K_程式人

    楓綺 發表在 痞客邦 留言(0) 人氣()