pytorch三层全连接层实现手写字母识别方式


Posted in Python onJanuary 14, 2020

先用最简单的三层全连接神经网络,然后添加激活层查看实验结果,最后加上批标准化验证是否有效

首先根据已有的模板定义网络结构SimpleNet,命名为net.py

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义三层全连接神经网络
class simpleNet(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):#输入维度,第一层的神经元个数、第二层的神经元个数,以及第三层的神经元个数
  super(simpleNet,self).__init__()
  self.layer1=nn.Linear(in_dim,n_hidden_1)
  self.layer2=nn.Linear(n_hidden_1,n_hidden_2)
  self.layer3=nn.Linear(n_hidden_2,out_dim)
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
 
 
#添加激活函数
class Activation_Net(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(NeutalNetwork,self).__init__()
  self.layer1=nn.Sequential(#Sequential组合结构
  nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(
  nn.Linear(n_hidden_1,n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(
  nn.Linear(n_hidden_2,out_dim))
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
#添加批标准化处理模块,皮标准化放在全连接的后面,非线性的前面
class Batch_Net(nn.Module):
 def _init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(Batch_net,self).__init__()
  self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNormld(n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNormld(n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(nn.Linear(n_hidden_2,out_dim))
 def forword(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x

训练网络,

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义一些超参数
import net
batch_size=64
learning_rate=1e-2
num_epoches=20
#预处理
data_tf=transforms.Compose(
[transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#将图像转化成tensor,然后继续标准化,就是减均值,除以方差

#读取数据集
train_dataset=datasets.MNIST(root='./data',train=True,transform=data_tf,download=True)
test_dataset=datasets.MNIST(root='./data',train=False,transform=data_tf)
#使用内置的函数导入数据集
train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

#导入网络,定义损失函数和优化方法
model=net.simpleNet(28*28,300,100,10)
if torch.cuda.is_available():#是否使用cuda加速
 model=model.cuda()
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=learning_rate)
import net
n_epochs=5
for epoch in range(n_epochs):
 running_loss=0.0
 running_correct=0
 print("epoch {}/{}".format(epoch,n_epochs))
 print("-"*10)
 for data in train_loader:
  img,label=data
  img=img.view(img.size(0),-1)
  if torch.cuda.is_available():
   img=img.cuda()
   label=label.cuda()
  else:
   img=Variable(img)
   label=Variable(label)
  out=model(img)#得到前向传播的结果
  loss=criterion(out,label)#得到损失函数
  print_loss=loss.data.item()
  optimizer.zero_grad()#归0梯度
  loss.backward()#反向传播
  optimizer.step()#优化
  running_loss+=loss.item()
  epoch+=1
  if epoch%50==0:
   print('epoch:{},loss:{:.4f}'.format(epoch,loss.data.item()))

训练的结果截图如下:

pytorch三层全连接层实现手写字母识别方式

测试网络

#测试网络
model.eval()#将模型变成测试模式
eval_loss=0
eval_acc=0
for data in test_loader:
 img,label=data
 img=img.view(img.size(0),-1)#测试集不需要反向传播,所以可以在前项传播的时候释放内存,节约内存空间
 if torch.cuda.is_available():
  img=Variable(img,volatile=True).cuda()
  label=Variable(label,volatile=True).cuda()
 else:
  img=Variable(img,volatile=True)
  label=Variable(label,volatile=True)
 out=model(img)
 loss=criterion(out,label)
 eval_loss+=loss.item()*label.size(0)
 _,pred=torch.max(out,1)
 num_correct=(pred==label).sum()
 eval_acc+=num_correct.item()
print('test loss:{:.6f},ac:{:.6f}'.format(eval_loss/(len(test_dataset)),eval_acc/(len(test_dataset))))

pytorch三层全连接层实现手写字母识别方式

训练的时候,还可以加入一些dropout,正则化,修改隐藏层神经元的个数,增加隐藏层数,可以自己添加。

以上这篇pytorch三层全连接层实现手写字母识别方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python类的动态修改的实例方法
Mar 24 Python
深入理解Python3中的http.client模块
Mar 29 Python
详解Python 2.6 升级至 Python 2.7 的实践心得
Apr 27 Python
python中使用zip函数出现错误的原因
Sep 28 Python
python 处理数字,把大于上限的数字置零实现方法
Jan 28 Python
python使用装饰器作日志处理的方法
Jul 11 Python
Python中字典与恒等运算符的用法分析
Aug 22 Python
使用Python爬虫库requests发送表单数据和JSON数据
Jan 25 Python
基于python实现百度语音识别和图灵对话
Nov 02 Python
Pycharm制作搞怪弹窗的实现代码
Feb 19 Python
Python中使用Lambda函数的5种用法
Apr 01 Python
python使用PySimpleGUI设置进度条及控件使用
Jun 10 Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
You might like
PHP4 与 MySQL 数据库操作函数详解
2006/12/06 PHP
PHP导出MySQL数据到Excel文件(fputcsv)
2011/07/03 PHP
php通过分类列表产生分类树数组的方法
2015/04/20 PHP
php简单压缩css样式示例
2016/09/22 PHP
PHP微信企业号开发之回调模式开启与用法示例
2017/11/25 PHP
如何设置一定时间内只能发送一次请求
2014/02/28 Javascript
js拼接html注意问题示例探讨
2014/07/14 Javascript
充分发挥Node.js程序性能的一些方法介绍
2015/06/23 Javascript
jQuery实现的简单折叠菜单(折叠面板)效果代码
2015/09/16 Javascript
JS响应鼠标点击实现两个滑块区间拖动效果
2015/10/26 Javascript
基于JavaScript实现图片点击弹出窗口而不是保存
2016/02/06 Javascript
实用jquery操作表单元素的简单代码
2016/07/04 Javascript
ES6新特性之类(Class)和继承(Extends)相关概念与用法分析
2017/05/24 Javascript
Vue脚手架的简单使用实例
2018/07/10 Javascript
vue集成百度UEditor富文本编辑器使用教程
2018/09/21 Javascript
javascript利用键盘控制小方块的移动
2020/04/20 Javascript
Python计算回文数的方法
2015/03/11 Python
Python安装使用命令行交互模块pexpect的基础教程
2016/05/12 Python
Python实现对百度云的文件上传(实例讲解)
2017/10/21 Python
Django中针对基于类的视图添加csrf_exempt实例代码
2018/02/11 Python
使用python进行拆分大文件的方法
2018/12/10 Python
python实现websocket的客户端压力测试
2019/06/25 Python
Python RabbitMQ实现简单的进程间通信示例
2020/07/02 Python
python GUI计算器的实现
2020/10/09 Python
Pandas中两个dataframe的交集和差集的示例代码
2020/12/13 Python
Bally巴利中国官网:经典瑞士鞋履、手袋及配饰奢侈品牌
2018/10/09 全球购物
Otticanet英国:最顶尖的世界名牌眼镜, 能得到打折季的价格
2019/02/10 全球购物
迪拜领先运动补剂零售品牌中文站:Sporter商城
2019/08/20 全球购物
中等生评语大全
2014/05/04 职场文书
花田少年史观后感
2015/06/16 职场文书
同学聚会感言一句话
2015/07/30 职场文书
2016年领导干部廉政承诺书
2016/03/24 职场文书
市直属机关2016年主题党日活动总结
2016/04/05 职场文书
关于Vue中的options选项
2022/03/22 Vue.js
vue修饰符.capture和.self的区别
2022/04/22 Vue.js
CSS元素定位之通过元素的标签或者元素的id、class属性定位详解
2022/09/23 HTML / CSS