Pytorch实现神经网络的分类方式


Posted in Python onJanuary 08, 2020

本文用于利用Pytorch实现神经网络的分类!!!

1.训练神经网络分类模型

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as Data
torch.manual_seed(1)#设置随机种子,使得每次生成的随机数是确定的
BATCH_SIZE = 5#设置batch size
 
#1.制作两类数据
n_data = torch.ones( 1000,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 1000 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 1000 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#当不使用batch size训练数据时,将Tensor放入Variable中
# x,y = Variable(x), Variable(y)
#绘制训练数据
# plt.scatter( x.data.numpy()[:,0], x.data.numpy()[:,1], c=y.data.numpy())
# plt.show()
 
#当使用batch size训练数据时,首先将tensor转化为Dataset格式
torch_dataset = Data.TensorDataset(x, y)
 
#将dataset放入DataLoader中
loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size = BATCH_SIZE,#设置batch size
 shuffle=True,#打乱数据
 num_workers=2#多线程读取数据
)
 
#2.前向传播过程
class Net(torch.nn.Module):#继承基类Module的属性和方法
 def __init__(self, input, hidden, output):
  super(Net, self).__init__()#继承__init__功能
  self.hidden = torch.nn.Linear(input, hidden)#隐层的线性输出
  self.out = torch.nn.Linear(hidden, output)#输出层线性输出
 def forward(self, x):
  x = F.relu(self.hidden(x))
  x = self.out(x)
  return x
 
# 训练模型的同时保存网络模型参数
def save():
 #3.利用自定义的前向传播过程设计网络,设置各层神经元数量
 # net = Net(input=2, hidden=10, output=2)
 # print("神经网络结构:",net)
 
 #3.快速搭建神经网络模型
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 #4.设置优化算法、学习率
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2 )
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2, momentum=0.8 )
 # optimizer = torch.optim.RMSprop( net.parameters(), lr=0.2, alpha=0.9 )
 optimizer = torch.optim.Adam( net.parameters(), lr=0.2, betas=(0.9,0.99) )
 
 #5.设置损失函数
 loss_func = torch.nn.CrossEntropyLoss()
 
 plt.ion()#打开画布,可视化更新过程
 #6.迭代训练
 for epoch in range(2):
  for step, (batch_x, batch_y) in enumerate(loader):
   out = net(batch_x)#输入训练集,获得当前迭代输出值
   loss = loss_func(out, batch_y)#获得当前迭代的损失
 
   optimizer.zero_grad()#清除上次迭代的更新梯度
   loss.backward()#反向传播
   optimizer.step()#更新权重
 
   if step%200==0:
    plt.cla()#清空之前画布上的内容
    entire_out = net(x)#测试整个训练集
    #获得当前softmax层最大概率对应的索引值
    pred = torch.max(F.softmax(entire_out), 1)[1]
    #将二维压缩为一维
    pred_y = pred.data.numpy().squeeze()
    label_y = y.data.numpy()
    plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
    accuracy = sum(pred_y == label_y)/y.size()
    print("第 %d 个epoch,第 %d 次迭代,准确率为 %.2f"%(epoch+1, step/200+1, accuracy))
    #在指定位置添加文本
    plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 15, 'color': 'red'})
    plt.pause(2)#图像显示时间
 
 #7.保存模型结构和参数
 torch.save(net, 'net.pkl')
 #7.只保存模型参数
 # torch.save(net.state_dict(), 'net_param.pkl')
 
 plt.ioff()#关闭画布
 plt.show()
 
if __name__ == '__main__':
 save()

2. 读取已训练好的模型测试数据

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#制作数据
n_data = torch.ones( 100,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 100 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 100 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#将Tensor放入Variable中
x,y = Variable(x), Variable(y)
 
#载入模型和参数
def restore_net():
 net = torch.load('net.pkl')
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
#仅载入模型参数,需要先创建网络模型
def restore_param():
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 net.load_state_dict( torch.load('net_param.pkl') )
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
 
if __name__ =='__main__':
 # restore_net()
 restore_param()

以上这篇Pytorch实现神经网络的分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python查找第k小元素代码分享
Dec 18 Python
Python3处理文件中每个词的方法
May 22 Python
python3之微信文章爬虫实例讲解
Jul 12 Python
浅析python打包工具distutils、setuptools
Apr 20 Python
通过python将大量文件按修改时间分类的方法
Oct 17 Python
Django 拆分model和view的实现方法
Aug 16 Python
使用python无账号无限制获取企查查信息的实例代码
Apr 17 Python
如何使用Python处理HDF格式数据及可视化问题
Jun 24 Python
浅谈对python中if、elif、else的误解
Aug 20 Python
Scrapy项目实战之爬取某社区用户详情
Sep 17 Python
Matplotlib animation模块实现动态图
Feb 25 Python
用Python监控你的朋友都在浏览哪些网站?
May 27 Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 #Python
基于python3抓取pinpoint应用信息入库
Jan 08 #Python
Python PyInstaller安装和使用教程详解
Jan 08 #Python
关于Pytorch的MLP模块实现方式
Jan 07 #Python
PyTorch 普通卷积和空洞卷积实例
Jan 07 #Python
Pytorch中膨胀卷积的用法详解
Jan 07 #Python
Python urlopen()和urlretrieve()用法解析
Jan 07 #Python
You might like
php trim 去除空字符的定义与语法介绍
2010/05/31 PHP
关于PHP session 存储方式的详细介绍
2013/06/25 PHP
php生成静态html页面的方法(2种方法)
2015/09/14 PHP
PHP编程获取各个时间段具体时间的方法
2017/05/26 PHP
JTrackBar水平拖动效果
2007/07/15 Javascript
JavaScript 计算当天是本年本月的第几周
2009/03/22 Javascript
JQuery获取元素文档大小、偏移和位置和滚动条位置的方法集合
2010/01/12 Javascript
屏蔽Flash右键信息的js代码
2010/01/17 Javascript
JS执行删除前的判断代码
2014/02/18 Javascript
jQuery简单实现隐藏以及显示特效
2015/02/26 Javascript
jQuery实现从身份证号中获取出生日期和性别的方法分析
2016/02/25 Javascript
jQuery插件HighCharts实现的2D条状图效果示例【附demo源码下载】
2017/03/15 Javascript
微信小程序录音与播放录音功能
2017/12/25 Javascript
layui点击按钮添加可编辑的一行方法
2018/08/15 Javascript
微信用户访问小程序的登录过程详解
2019/09/20 Javascript
JavaScript 函数用法详解【函数定义、参数、绑定、作用域、闭包等】
2020/05/12 Javascript
JavaScript如何判断对象有某属性
2020/07/03 Javascript
Python中的filter()函数的用法
2015/04/27 Python
Python 装饰器深入理解
2017/03/16 Python
解决python大批量读写.doc文件的问题
2018/05/08 Python
浅谈解除装饰器作用(python3新增)
2018/10/15 Python
python 通过类中一个方法获取另一个方法变量的实例
2019/01/22 Python
在Pandas中DataFrame数据合并,连接(concat,merge,join)的实例
2019/01/29 Python
Flask框架踩坑之ajax跨域请求实现
2019/02/22 Python
python 整数越界问题详解
2019/06/27 Python
Python变量访问权限控制详解
2019/06/29 Python
Django 实现前端图片压缩功能的方法
2019/08/07 Python
python 进程间数据共享multiProcess.Manger实现解析
2019/09/23 Python
解决python gdal投影坐标系转换的问题
2020/01/17 Python
Python中and和or如何使用
2020/05/28 Python
Ryderwear澳洲官网:澳大利亚高端健身训练装备品牌
2018/09/18 全球购物
教师年终个人自我评价
2013/10/04 职场文书
党委班子剖析材料
2014/08/21 职场文书
社区务虚会发言材料
2014/10/20 职场文书
民主生活会意见
2015/06/05 职场文书
2016领导干部廉洁自律心得体会
2016/01/13 职场文书