pytorch三层全连接层实现手写字母识别方式


Posted in Python onJanuary 14, 2020

先用最简单的三层全连接神经网络,然后添加激活层查看实验结果,最后加上批标准化验证是否有效

首先根据已有的模板定义网络结构SimpleNet,命名为net.py

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义三层全连接神经网络
class simpleNet(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):#输入维度,第一层的神经元个数、第二层的神经元个数,以及第三层的神经元个数
  super(simpleNet,self).__init__()
  self.layer1=nn.Linear(in_dim,n_hidden_1)
  self.layer2=nn.Linear(n_hidden_1,n_hidden_2)
  self.layer3=nn.Linear(n_hidden_2,out_dim)
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
 
 
#添加激活函数
class Activation_Net(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(NeutalNetwork,self).__init__()
  self.layer1=nn.Sequential(#Sequential组合结构
  nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(
  nn.Linear(n_hidden_1,n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(
  nn.Linear(n_hidden_2,out_dim))
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
#添加批标准化处理模块,皮标准化放在全连接的后面,非线性的前面
class Batch_Net(nn.Module):
 def _init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(Batch_net,self).__init__()
  self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNormld(n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNormld(n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(nn.Linear(n_hidden_2,out_dim))
 def forword(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x

训练网络,

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义一些超参数
import net
batch_size=64
learning_rate=1e-2
num_epoches=20
#预处理
data_tf=transforms.Compose(
[transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#将图像转化成tensor,然后继续标准化,就是减均值,除以方差

#读取数据集
train_dataset=datasets.MNIST(root='./data',train=True,transform=data_tf,download=True)
test_dataset=datasets.MNIST(root='./data',train=False,transform=data_tf)
#使用内置的函数导入数据集
train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

#导入网络,定义损失函数和优化方法
model=net.simpleNet(28*28,300,100,10)
if torch.cuda.is_available():#是否使用cuda加速
 model=model.cuda()
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=learning_rate)
import net
n_epochs=5
for epoch in range(n_epochs):
 running_loss=0.0
 running_correct=0
 print("epoch {}/{}".format(epoch,n_epochs))
 print("-"*10)
 for data in train_loader:
  img,label=data
  img=img.view(img.size(0),-1)
  if torch.cuda.is_available():
   img=img.cuda()
   label=label.cuda()
  else:
   img=Variable(img)
   label=Variable(label)
  out=model(img)#得到前向传播的结果
  loss=criterion(out,label)#得到损失函数
  print_loss=loss.data.item()
  optimizer.zero_grad()#归0梯度
  loss.backward()#反向传播
  optimizer.step()#优化
  running_loss+=loss.item()
  epoch+=1
  if epoch%50==0:
   print('epoch:{},loss:{:.4f}'.format(epoch,loss.data.item()))

训练的结果截图如下:

pytorch三层全连接层实现手写字母识别方式

测试网络

#测试网络
model.eval()#将模型变成测试模式
eval_loss=0
eval_acc=0
for data in test_loader:
 img,label=data
 img=img.view(img.size(0),-1)#测试集不需要反向传播,所以可以在前项传播的时候释放内存,节约内存空间
 if torch.cuda.is_available():
  img=Variable(img,volatile=True).cuda()
  label=Variable(label,volatile=True).cuda()
 else:
  img=Variable(img,volatile=True)
  label=Variable(label,volatile=True)
 out=model(img)
 loss=criterion(out,label)
 eval_loss+=loss.item()*label.size(0)
 _,pred=torch.max(out,1)
 num_correct=(pred==label).sum()
 eval_acc+=num_correct.item()
print('test loss:{:.6f},ac:{:.6f}'.format(eval_loss/(len(test_dataset)),eval_acc/(len(test_dataset))))

pytorch三层全连接层实现手写字母识别方式

训练的时候,还可以加入一些dropout,正则化,修改隐藏层神经元的个数,增加隐藏层数,可以自己添加。

以上这篇pytorch三层全连接层实现手写字母识别方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
videocapture库制作python视频高速传输程序
Dec 23 Python
分享一个常用的Python模拟登陆类
Mar 29 Python
python实现单向链表详解
Feb 08 Python
Python中将变量按行写入txt文本中的方法
Apr 03 Python
Python编程中NotImplementedError的使用方法
Apr 21 Python
深入理解Django-Signals信号量
Feb 19 Python
关于Python3 类方法、静态方法新解
Aug 30 Python
python 发送json数据操作实例分析
Oct 15 Python
django连接mysql数据库及建表操作实例详解
Dec 10 Python
Python3以GitHub为例来实现模拟登录和爬取的实例讲解
Jul 30 Python
python字典按照value排序方法
Dec 28 Python
手把手教你实现PyTorch的MNIST数据集
Jun 28 Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
You might like
咖啡知识大全
2021/03/03 新手入门
全文搜索和替换
2006/10/09 PHP
PHP调用MySQL的存储过程的实现代码
2008/08/12 PHP
PHP 文件缓存的性能测试
2010/04/25 PHP
PHP的5个安全措施小结
2012/07/17 PHP
PHP中mysqli_affected_rows作用行数返回值分析
2014/12/26 PHP
PHP的Yii框架中使用数据库的配置和SQL操作实例教程
2016/03/17 PHP
PHP5.6.8连接SQL Server 2008 R2数据库常用技巧分析总结
2019/05/06 PHP
js之WEB开发调试利器:Firebug 下载
2007/01/13 Javascript
JQuery 无废话系列教程(一) jquery入门 [推荐]
2009/06/23 Javascript
倒记时60刷新网页的js代码
2014/02/18 Javascript
什么是Node.js?Node.js详细介绍
2014/06/01 Javascript
javascript基于HTML5 canvas制作画箭头组件
2014/06/25 Javascript
在JavaScript中重写jQuery对象的方法实例教程
2014/08/25 Javascript
javascript自定义右键弹出菜单实现方法
2015/05/25 Javascript
javascript获取wx.config内部字段解决微信分享
2016/03/09 Javascript
使用jQuery UI库开发Web界面的简单入门指引
2016/04/22 Javascript
浅析Bootstrap组件之面板组件
2016/05/04 Javascript
javascript 动态生成css代码的两种方法
2017/03/17 Javascript
玩转Koa之koa-router原理解析
2018/12/29 Javascript
vue的滚动条插件实现代码
2019/09/07 Javascript
JavaScript中0、空字符串、'0'是true还是false的知识点分享
2019/09/16 Javascript
js实现点赞按钮功能的实例代码
2020/03/06 Javascript
vue+vuex+axios从后台获取数据存入vuex,组件之间共享数据操作
2020/07/31 Javascript
python实现监控linux性能及进程消耗性能的方法
2014/07/25 Python
python中split方法用法分析
2015/04/17 Python
在Linux命令行终端中使用python的简单方法(推荐)
2017/01/23 Python
Python之web模板应用
2017/12/26 Python
Django中多种重定向方法使用详解
2019/07/17 Python
聊聊python在linux下与windows下导入模块的区别说明
2021/03/03 Python
英国行业制服供应商:Alexandra
2019/09/14 全球购物
学生评语大全
2014/04/18 职场文书
学习三严三实对照检查材料思想汇报
2014/09/22 职场文书
民主评议教师党员自我评价
2015/03/04 职场文书
mybatis中sql语句CDATA标签的用法说明
2021/06/30 Java/Android
利用Java连接Hadoop进行编程
2022/06/28 Java/Android