PyTorch 实现 VGGNet

PyTorch 实践

Posted by WEW on April 26, 2019

VGG的网络结构

如下图所示。VGG16 由 13 个 3×3 卷积层（分为 5 段，每段之后接一个 2×2 最大池化层）和 3 个全连接层组成。

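使用 PyTorch 框架构建 VGG16。下面的实现与 SSD 中使用的 VGG 骨干网络类似：第 5 个池化层改为 3×3、步长为 1 的最大池化（'MP3'），全连接层改写为卷积层（'FC' 为膨胀率 6 的 3×3 卷积，'FC1'、'FC2' 为 1×1 卷积）。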

# 定义VGG16网络
import torch
import pdb
import torch.nn as nn
import torch.nn.functional as F
# Layer specification consumed by VGGNet: an int is a 3x3 conv with that many
# output channels, 'MP' a 2x2/stride-2 max-pool, 'MP3' a 3x3/stride-1 max-pool,
# and 'FC'/'FC1'/'FC2' the convolutional stand-ins for the fully connected head.
chanels = [
    64, 64, 'MP',
    128, 128, 'MP',
    256, 256, 256, 'MP',
    512, 512, 512, 'MP',
    512, 512, 512, 'MP3',
    'FC', 'FC1', 'FC2',
]
class VGGNet(nn.Module):
    """VGG16-style fully convolutional network built from a layer spec list.

    Each entry of ``Chanels`` appends layers, in order:

    * ``int``   -- 3x3 convolution (padding 1) to that many channels, then ReLU.
    * ``'MP'``  -- 2x2 max pooling, stride 2 (halves H and W, rounding down).
    * ``'MP3'`` -- 3x3 max pooling, stride 1, padding 1 (keeps H and W).
    * ``'FC'``  -- 3x3 convolution to 1024 channels with dilation 6 and
      padding 6 (keeps H and W), then ReLU.
    * ``'FC1'`` -- 1x1 convolution to 1024 channels, then ReLU.
    * ``'FC2'`` -- 1x1 convolution to ``num_class`` channels (no activation).

    Args:
        Chanels: Layer specification, e.g. the module-level ``chanels`` list.
        num_class: Number of output channels produced by ``'FC2'``.
        in_channels: Channel count of the input tensor (3 for RGB images).

    Raises:
        ValueError: If ``Chanels`` contains an unrecognised string token.
    """

    def __init__(self, Chanels, num_class, in_channels=3):
        super().__init__()
        self.num_class = num_class
        layers = []
        # Running channel count: every convolution (including the 'FC*' ones,
        # which previously hard-coded 512/1024 inputs) is wired to whatever
        # the preceding layer actually produces.
        current = in_channels
        for ar in Chanels:
            if ar == 'MP':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif ar == 'MP3':
                layers.append(nn.MaxPool2d(3, stride=1, padding=1))
            elif ar == 'FC':
                # Dilation 6 with padding 6 keeps the spatial size for a 3x3 kernel.
                layers.append(nn.Conv2d(current, 1024, 3, padding=6, dilation=6))
                layers.append(nn.ReLU(inplace=True))
                current = 1024
            elif ar == 'FC1':
                layers.append(nn.Conv2d(current, 1024, 1))
                layers.append(nn.ReLU(inplace=True))
                current = 1024
            elif ar == 'FC2':
                # Final scoring layer: raw outputs, no activation.
                layers.append(nn.Conv2d(current, self.num_class, 1))
                current = self.num_class
            elif isinstance(ar, str):
                raise ValueError(f"unknown layer token in Chanels: {ar!r}")
            else:
                layers.append(nn.Conv2d(in_channels=current, out_channels=ar,
                                        kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                current = ar
        # Attribute name ``Layers`` is kept so state_dict keys
        # (``Layers.<i>.weight`` / ``.bias``) are unchanged.
        self.Layers = nn.ModuleList(layers)

    def forward(self, x):
        """Apply every layer of ``self.Layers`` in order and return the result."""
        for layer in self.Layers:
            x = layer(x)
        return x
# Smoke test: push one random 3x300x300 input through the network and
# print the shape of the resulting output map.
inputs = torch.randn(1, 3, 300, 300)
vgg16 = VGGNet(Chanels=chanels, num_class=10)
output = vgg16(inputs)
print(output.shape)