如何搭建VGG网络实现Mnist数据集的图像分类

电子说

1.3w人已加入

描述

1 问题

如何搭建VGG网络,实现Mnist数据集的图像分类?

2 方法

步骤:

首先导包

VGG11由8个卷积,三个全连接组成,注意池化只改变特征图大小,不改变通道数

给定x查看最后结果

x = torch.rand(128,3,224,224)
net = MyNet()
out = net(x)
print(out.shape)
#torch.Size([128, 1000])

class MyNet(nn.Module):
   def __init__(self) -> None:
       super().__init__()
       #(1)conv3-64
       self.conv1 = nn.Conv2d(
           in_channels=3,
           out_channels=64,
           kernel_size=3,
           stride=1,
           padding=1 #! 不改变特征图的大小
       )
       #! 池化只改变特征图大小,不改变通道数
       self.max_pool_1 = nn.MaxPool2d(2)
       #(2)conv3-128
       self.conv2 = nn.Conv2d(
           in_channels=64,
           out_channels=128,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_2 = nn.MaxPool2d(2)
       #(3) conv3-256,conv3-256
       self.conv3_1 = nn.Conv2d(
           in_channels=128,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1)
       self.conv3_2 = nn.Conv2d(
           in_channels=256,
           out_channels=256,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_3 = nn.MaxPool2d(2)
       #(4)conv3-512,conv3-512
       self.conv4_1 = nn.Conv2d(
           in_channels=256,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv4_2 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_4 = nn.MaxPool2d(2)
       #(5)conv3-512,conv3-512
       self.conv5_1 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.conv5_2 = nn.Conv2d(
           in_channels=512,
           out_channels=512,
           kernel_size=3,
           stride=1,
           padding=1
       )
       self.max_pool_5 = nn.MaxPool2d(2)
       #(6)
       self.fc1 = nn.Linear(25088,4096)
       self.fc2 = nn.Linear(4096,4096)
       self.fc3 = nn.Linear(4096,1000)
   def forward(self,x):
       x = self.conv1(x)
       print(x.shape)
       x = self.max_pool_1(x)
       print(x.shape)
       x = self.conv2(x)
       print(x.shape)
       x = self.max_pool_2(x)
       print(x.shape)
       x = self.conv3_1(x)
       print(x.shape)
       x = self.conv3_2(x)
       print(x.shape)
       x = self.max_pool_3(x)
       print(x.shape)
       x = self.conv4_1(x)
       print(x.shape)
       x = self.conv4_2(x)
       print(x.shape)
       x = self.max_pool_4(x)
       print(x.shape)
       x = self.conv5_1(x)
       print(x.shape)
       x = self.conv5_2(x)
       print(x.shape)
       x = self.max_pool_5(x)
       print(x.shape)
       x = torch.flatten(x,1)
       print(x.shape)
       x = self.fc1(x)
       print(x.shape)
       x = self.fc2(x)
       print(x.shape)
       out = self.fc3(x)
       return out

Import torch
from torch import nn

 

3 结语

   通过本周学习让我学会了VGG11网络,从实验中我遇到的容易出错的地方是卷积的in_features和out_features容易出错,尺寸不对的时候就会报错,在多个卷积的情况下尤其需要注意,第二点容易出错的地方是卷积以及池化所有结束后,一定要使用torch.flatten进行拉伸,第三点容易出错的地方是fc1的in_features,这个我通过使用断点的方法,得到fc1前一步的size值,从而得到in_features的值,从中收获颇深。

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分