pytorch三层全连接层实现手写字母识别方式


Posted in Python onJanuary 14, 2020

先用最简单的三层全连接神经网络,然后添加激活层查看实验结果,最后加上批标准化验证是否有效

首先根据已有的模板定义网络结构SimpleNet,命名为net.py

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义三层全连接神经网络
class simpleNet(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):#输入维度,第一层的神经元个数、第二层的神经元个数,以及第三层的神经元个数
  super(simpleNet,self).__init__()
  self.layer1=nn.Linear(in_dim,n_hidden_1)
  self.layer2=nn.Linear(n_hidden_1,n_hidden_2)
  self.layer3=nn.Linear(n_hidden_2,out_dim)
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
 
 
#添加激活函数
class Activation_Net(nn.Module):
 def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(NeutalNetwork,self).__init__()
  self.layer1=nn.Sequential(#Sequential组合结构
  nn.Linear(in_dim,n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(
  nn.Linear(n_hidden_1,n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(
  nn.Linear(n_hidden_2,out_dim))
 def forward(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x
#添加批标准化处理模块,皮标准化放在全连接的后面,非线性的前面
class Batch_Net(nn.Module):
 def _init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):
  super(Batch_net,self).__init__()
  self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNormld(n_hidden_1),nn.ReLU(True))
  self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNormld(n_hidden_2),nn.ReLU(True))
  self.layer3=nn.Sequential(nn.Linear(n_hidden_2,out_dim))
 def forword(self,x):
  x=self.layer1(x)
  x=self.layer2(x)
  x=self.layer3(x)
  return x

训练网络,

import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义一些超参数
import net
batch_size=64
learning_rate=1e-2
num_epoches=20
#预处理
data_tf=transforms.Compose(
[transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#将图像转化成tensor,然后继续标准化,就是减均值,除以方差

#读取数据集
train_dataset=datasets.MNIST(root='./data',train=True,transform=data_tf,download=True)
test_dataset=datasets.MNIST(root='./data',train=False,transform=data_tf)
#使用内置的函数导入数据集
train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)

#导入网络,定义损失函数和优化方法
model=net.simpleNet(28*28,300,100,10)
if torch.cuda.is_available():#是否使用cuda加速
 model=model.cuda()
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=learning_rate)
import net
n_epochs=5
for epoch in range(n_epochs):
 running_loss=0.0
 running_correct=0
 print("epoch {}/{}".format(epoch,n_epochs))
 print("-"*10)
 for data in train_loader:
  img,label=data
  img=img.view(img.size(0),-1)
  if torch.cuda.is_available():
   img=img.cuda()
   label=label.cuda()
  else:
   img=Variable(img)
   label=Variable(label)
  out=model(img)#得到前向传播的结果
  loss=criterion(out,label)#得到损失函数
  print_loss=loss.data.item()
  optimizer.zero_grad()#归0梯度
  loss.backward()#反向传播
  optimizer.step()#优化
  running_loss+=loss.item()
  epoch+=1
  if epoch%50==0:
   print('epoch:{},loss:{:.4f}'.format(epoch,loss.data.item()))

训练的结果截图如下:

pytorch三层全连接层实现手写字母识别方式

测试网络

#测试网络
model.eval()#将模型变成测试模式
eval_loss=0
eval_acc=0
for data in test_loader:
 img,label=data
 img=img.view(img.size(0),-1)#测试集不需要反向传播,所以可以在前项传播的时候释放内存,节约内存空间
 if torch.cuda.is_available():
  img=Variable(img,volatile=True).cuda()
  label=Variable(label,volatile=True).cuda()
 else:
  img=Variable(img,volatile=True)
  label=Variable(label,volatile=True)
 out=model(img)
 loss=criterion(out,label)
 eval_loss+=loss.item()*label.size(0)
 _,pred=torch.max(out,1)
 num_correct=(pred==label).sum()
 eval_acc+=num_correct.item()
print('test loss:{:.6f},ac:{:.6f}'.format(eval_loss/(len(test_dataset)),eval_acc/(len(test_dataset))))

pytorch三层全连接层实现手写字母识别方式

训练的时候,还可以加入一些dropout,正则化,修改隐藏层神经元的个数,增加隐藏层数,可以自己添加。

以上这篇pytorch三层全连接层实现手写字母识别方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现中文分词FMM算法实例
Jul 10 Python
Python实现查找系统盘中需要找的字符
Jul 14 Python
Django框架中数据的连锁查询和限制返回数据的方法
Jul 17 Python
python计算列表内各元素的个数实例
Jun 29 Python
详解pandas安装若干异常及解决方案总结
Jan 10 Python
Python列表的切片实例讲解
Aug 20 Python
python3 反射的四种基本方法解析
Aug 26 Python
Pycharm连接远程服务器过程图解
Apr 30 Python
keras的三种模型实现与区别说明
Jul 03 Python
基于python实现生成指定大小txt文档
Jul 20 Python
基于注解实现 SpringBoot 接口防刷的方法
Mar 02 Python
Python编程编写完善的命令行工具
Sep 15 Python
Python实现bilibili时间长度查询的示例代码
Jan 14 #Python
基于python监控程序是否关闭
Jan 14 #Python
关于pytorch中全连接神经网络搭建两种模式详解
Jan 14 #Python
使用Pytorch来拟合函数方式
Jan 14 #Python
pytorch 模拟关系拟合——回归实例
Jan 14 #Python
PyTorch实现AlexNet示例
Jan 14 #Python
Pytorch 实现focal_loss 多类别和二分类示例
Jan 14 #Python
You might like
apache+mysql+php+ssl服务器之完全安装攻略
2006/09/05 PHP
php中数组首字符过滤功能代码
2012/07/31 PHP
php使浏览器直接下载pdf文件的方法
2013/11/15 PHP
使用PHP连接多种数据库的实现代码(mysql,access,sqlserver,Oracle)
2016/12/21 PHP
php layui实现前端多图上传实例
2019/07/30 PHP
laravel多条件查询方法(and,or嵌套查询)
2019/10/09 PHP
PHP的new static和new self的区别与使用
2019/11/27 PHP
Yii框架安装简明教程
2020/05/15 PHP
PHP unset函数原理及使用方法解析
2020/08/14 PHP
jquery弹出层类代码分享
2013/12/27 Javascript
javascript调试之DOM断点调试法使用技巧分享
2014/04/15 Javascript
分析了一下JQuery中的extend方法实现原理
2015/02/27 Javascript
javascript实现简单的二级联动
2015/03/19 Javascript
JavaScript 常见安全漏洞和自动化检测技术
2015/08/21 Javascript
Knockout自定义绑定创建方法
2015/12/26 Javascript
jq实现左滑显示删除按钮,点击删除实现删除数据功能(推荐)
2016/08/23 Javascript
详解VueJs前后端分离跨域问题
2017/05/24 Javascript
react native 获取地理位置的方法示例
2018/08/28 Javascript
Python编程中对super函数的正确理解和用法解析
2016/07/02 Python
python 容器总结整理
2017/04/04 Python
Python线程池模块ThreadPoolExecutor用法分析
2018/12/28 Python
Python hexstring-list-str之间的转换方法
2019/06/12 Python
python用win32gui遍历窗口并设置窗口位置的方法
2019/07/26 Python
关于tensorflow的几种参数初始化方法小结
2020/01/04 Python
Django Form常用功能及代码示例
2020/10/13 Python
一款纯css3实现的非常实用的鼠标悬停特效演示
2014/11/05 HTML / CSS
玩具反斗城葡萄牙官方商城:Toys"R"Us葡萄牙
2016/10/21 全球购物
美国用餐电影院:Alamo Drafthouse Cinema
2020/01/23 全球购物
什么是Smarty变量操作符?如何使用Smarty变量操作符
2014/07/18 面试题
酒店前厅员工辞职信
2014/01/08 职场文书
幼儿评语大全
2014/04/30 职场文书
清明节扫墓活动总结
2015/02/09 职场文书
2015年大学生党员承诺书
2015/04/27 职场文书
行政处罚事先告知书
2015/07/01 职场文书
基于Java的MathML转图片的方法(示例代码)
2021/06/23 Java/Android
django中websocket的具体使用
2022/01/22 Python