Pytorch实现神经网络的分类方式


Posted in Python onJanuary 08, 2020

本文用于利用Pytorch实现神经网络的分类!!!

1.训练神经网络分类模型

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as Data
torch.manual_seed(1)#设置随机种子,使得每次生成的随机数是确定的
BATCH_SIZE = 5#设置batch size
 
#1.制作两类数据
n_data = torch.ones( 1000,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 1000 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 1000 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#当不使用batch size训练数据时,将Tensor放入Variable中
# x,y = Variable(x), Variable(y)
#绘制训练数据
# plt.scatter( x.data.numpy()[:,0], x.data.numpy()[:,1], c=y.data.numpy())
# plt.show()
 
#当使用batch size训练数据时,首先将tensor转化为Dataset格式
torch_dataset = Data.TensorDataset(x, y)
 
#将dataset放入DataLoader中
loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size = BATCH_SIZE,#设置batch size
 shuffle=True,#打乱数据
 num_workers=2#多线程读取数据
)
 
#2.前向传播过程
class Net(torch.nn.Module):#继承基类Module的属性和方法
 def __init__(self, input, hidden, output):
  super(Net, self).__init__()#继承__init__功能
  self.hidden = torch.nn.Linear(input, hidden)#隐层的线性输出
  self.out = torch.nn.Linear(hidden, output)#输出层线性输出
 def forward(self, x):
  x = F.relu(self.hidden(x))
  x = self.out(x)
  return x
 
# 训练模型的同时保存网络模型参数
def save():
 #3.利用自定义的前向传播过程设计网络,设置各层神经元数量
 # net = Net(input=2, hidden=10, output=2)
 # print("神经网络结构:",net)
 
 #3.快速搭建神经网络模型
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 #4.设置优化算法、学习率
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2 )
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2, momentum=0.8 )
 # optimizer = torch.optim.RMSprop( net.parameters(), lr=0.2, alpha=0.9 )
 optimizer = torch.optim.Adam( net.parameters(), lr=0.2, betas=(0.9,0.99) )
 
 #5.设置损失函数
 loss_func = torch.nn.CrossEntropyLoss()
 
 plt.ion()#打开画布,可视化更新过程
 #6.迭代训练
 for epoch in range(2):
  for step, (batch_x, batch_y) in enumerate(loader):
   out = net(batch_x)#输入训练集,获得当前迭代输出值
   loss = loss_func(out, batch_y)#获得当前迭代的损失
 
   optimizer.zero_grad()#清除上次迭代的更新梯度
   loss.backward()#反向传播
   optimizer.step()#更新权重
 
   if step%200==0:
    plt.cla()#清空之前画布上的内容
    entire_out = net(x)#测试整个训练集
    #获得当前softmax层最大概率对应的索引值
    pred = torch.max(F.softmax(entire_out), 1)[1]
    #将二维压缩为一维
    pred_y = pred.data.numpy().squeeze()
    label_y = y.data.numpy()
    plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
    accuracy = sum(pred_y == label_y)/y.size()
    print("第 %d 个epoch,第 %d 次迭代,准确率为 %.2f"%(epoch+1, step/200+1, accuracy))
    #在指定位置添加文本
    plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 15, 'color': 'red'})
    plt.pause(2)#图像显示时间
 
 #7.保存模型结构和参数
 torch.save(net, 'net.pkl')
 #7.只保存模型参数
 # torch.save(net.state_dict(), 'net_param.pkl')
 
 plt.ioff()#关闭画布
 plt.show()
 
if __name__ == '__main__':
 save()

2. 读取已训练好的模型测试数据

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#制作数据
n_data = torch.ones( 100,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 100 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 100 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#将Tensor放入Variable中
x,y = Variable(x), Variable(y)
 
#载入模型和参数
def restore_net():
 net = torch.load('net.pkl')
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
#仅载入模型参数,需要先创建网络模型
def restore_param():
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 net.load_state_dict( torch.load('net_param.pkl') )
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
 
if __name__ =='__main__':
 # restore_net()
 restore_param()

以上这篇Pytorch实现神经网络的分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现多线程采集的2个代码例子
Jul 07 Python
Python通过PIL获取图片主要颜色并和颜色库进行对比的方法
Mar 19 Python
Python实现针对中文排序的方法
May 09 Python
python使用sqlite3时游标使用方法
Mar 13 Python
Python socket套接字实现C/S模式远程命令执行功能案例
Jul 06 Python
Python3.7实现中控考勤机自动连接
Aug 28 Python
pycharm debug功能实现跳到循环末尾的方法
Nov 29 Python
Python将列表数据写入文件(txt, csv,excel)
Apr 03 Python
Django之腾讯云短信的实现
Jun 12 Python
基于django2.2连oracle11g解决版本冲突的问题
Jul 02 Python
对Keras自带Loss Function的深入研究
May 25 Python
Python办公自动化PPT批量转换操作
Sep 15 Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 #Python
基于python3抓取pinpoint应用信息入库
Jan 08 #Python
Python PyInstaller安装和使用教程详解
Jan 08 #Python
关于Pytorch的MLP模块实现方式
Jan 07 #Python
PyTorch 普通卷积和空洞卷积实例
Jan 07 #Python
Pytorch中膨胀卷积的用法详解
Jan 07 #Python
Python urlopen()和urlretrieve()用法解析
Jan 07 #Python
You might like
用PHP调用数据库的存贮过程
2006/10/09 PHP
Zend公司全球首推PHP认证
2006/10/09 PHP
19个Android常用工具类汇总
2014/12/30 PHP
PHP7新特性foreach 修改示例介绍
2016/08/26 PHP
关于PHP中字符串与多进制转换函数的实例代码
2016/11/03 PHP
PHP正则替换函数preg_replace()报错:Notice Use of undefined constant的解决方法分析
2017/02/04 PHP
PHP实现动态压缩js与css文件的方法
2018/05/02 PHP
JavaScript CSS修改学习第二章 样式
2010/02/19 Javascript
js几秒以后倒计时跳转示例
2013/12/26 Javascript
JS实现div居中示例
2014/04/17 Javascript
JQuery操作textarea,input,select,checkbox方法
2015/09/02 Javascript
bootstrap选项卡使用方法解析
2017/01/11 Javascript
canvas压缩图片转换成base64格式输出文件流
2017/03/09 Javascript
vue 全选与反选的实现方法(无Bug 新手看过来)
2018/02/09 Javascript
react 不用插件实现数字滚动的效果示例
2020/04/14 Javascript
Python 匹配任意字符(包括换行符)的正则表达式写法
2009/10/29 Python
Python实现同时兼容老版和新版Socket协议的一个简单WebSocket服务器
2014/06/04 Python
Python3 Random模块代码详解
2017/12/04 Python
Python cookbook(数据结构与算法)将多个映射合并为单个映射的方法
2018/04/19 Python
PyTorch上实现卷积神经网络CNN的方法
2018/04/28 Python
Django项目中用JS实现加载子页面并传值的方法
2018/05/28 Python
使用python去除图片白色像素的实例
2019/12/12 Python
python爬虫开发之Request模块从安装到详细使用方法与实例全解
2020/03/09 Python
Python request操作步骤及代码实例
2020/04/13 Python
Python通过kerberos安全认证操作kafka方式
2020/06/06 Python
Python并发爬虫常用实现方法解析
2020/11/19 Python
使用html5实现表格实现标题合并的实例代码
2019/05/13 HTML / CSS
美国在线面料商店:Online Fabric Store
2018/07/26 全球购物
DELPHI中如何调用API,可举例说明
2014/01/16 面试题
入党申请人的自我鉴定
2013/12/01 职场文书
幼儿园门卫岗位职责范本
2014/07/02 职场文书
四风批评与自我批评发言稿
2014/10/14 职场文书
学校党的群众路线教育实践活动整改措施
2014/10/25 职场文书
2016年国培研修日志
2015/11/13 职场文书
Spring Cloud OAuth2实现自定义token返回格式
2022/06/25 Java/Android
python如何将mat文件转为png
2022/07/15 Python