Pytorch实现神经网络的分类方式


Posted in Python onJanuary 08, 2020

本文用于利用Pytorch实现神经网络的分类!!!

1.训练神经网络分类模型

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as Data
torch.manual_seed(1)#设置随机种子,使得每次生成的随机数是确定的
BATCH_SIZE = 5#设置batch size
 
#1.制作两类数据
n_data = torch.ones( 1000,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 1000 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 1000 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#当不使用batch size训练数据时,将Tensor放入Variable中
# x,y = Variable(x), Variable(y)
#绘制训练数据
# plt.scatter( x.data.numpy()[:,0], x.data.numpy()[:,1], c=y.data.numpy())
# plt.show()
 
#当使用batch size训练数据时,首先将tensor转化为Dataset格式
torch_dataset = Data.TensorDataset(x, y)
 
#将dataset放入DataLoader中
loader = Data.DataLoader(
 dataset=torch_dataset,
 batch_size = BATCH_SIZE,#设置batch size
 shuffle=True,#打乱数据
 num_workers=2#多线程读取数据
)
 
#2.前向传播过程
class Net(torch.nn.Module):#继承基类Module的属性和方法
 def __init__(self, input, hidden, output):
  super(Net, self).__init__()#继承__init__功能
  self.hidden = torch.nn.Linear(input, hidden)#隐层的线性输出
  self.out = torch.nn.Linear(hidden, output)#输出层线性输出
 def forward(self, x):
  x = F.relu(self.hidden(x))
  x = self.out(x)
  return x
 
# 训练模型的同时保存网络模型参数
def save():
 #3.利用自定义的前向传播过程设计网络,设置各层神经元数量
 # net = Net(input=2, hidden=10, output=2)
 # print("神经网络结构:",net)
 
 #3.快速搭建神经网络模型
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 #4.设置优化算法、学习率
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2 )
 # optimizer = torch.optim.SGD( net.parameters(), lr=0.2, momentum=0.8 )
 # optimizer = torch.optim.RMSprop( net.parameters(), lr=0.2, alpha=0.9 )
 optimizer = torch.optim.Adam( net.parameters(), lr=0.2, betas=(0.9,0.99) )
 
 #5.设置损失函数
 loss_func = torch.nn.CrossEntropyLoss()
 
 plt.ion()#打开画布,可视化更新过程
 #6.迭代训练
 for epoch in range(2):
  for step, (batch_x, batch_y) in enumerate(loader):
   out = net(batch_x)#输入训练集,获得当前迭代输出值
   loss = loss_func(out, batch_y)#获得当前迭代的损失
 
   optimizer.zero_grad()#清除上次迭代的更新梯度
   loss.backward()#反向传播
   optimizer.step()#更新权重
 
   if step%200==0:
    plt.cla()#清空之前画布上的内容
    entire_out = net(x)#测试整个训练集
    #获得当前softmax层最大概率对应的索引值
    pred = torch.max(F.softmax(entire_out), 1)[1]
    #将二维压缩为一维
    pred_y = pred.data.numpy().squeeze()
    label_y = y.data.numpy()
    plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
    accuracy = sum(pred_y == label_y)/y.size()
    print("第 %d 个epoch,第 %d 次迭代,准确率为 %.2f"%(epoch+1, step/200+1, accuracy))
    #在指定位置添加文本
    plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 15, 'color': 'red'})
    plt.pause(2)#图像显示时间
 
 #7.保存模型结构和参数
 torch.save(net, 'net.pkl')
 #7.只保存模型参数
 # torch.save(net.state_dict(), 'net_param.pkl')
 
 plt.ioff()#关闭画布
 plt.show()
 
if __name__ == '__main__':
 save()

2. 读取已训练好的模型测试数据

import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as F
 
#制作数据
n_data = torch.ones( 100,2 )
x0 = torch.normal( 1.5*n_data, 1 )#均值为2 标准差为1
y0 = torch.zeros( 100 )
 
x1 = torch.normal( -1.5*n_data,1 )#均值为-2 标准差为1
y1 = torch.ones( 100 )
print("数据集维度:",x0.size(),y0.size())
 
#合并训练数据集,并转化数据类型为浮点型或整型
x = torch.cat( (x0,x1),0 ).type( torch.FloatTensor )
y = torch.cat( (y0,y1) ).type( torch.LongTensor )
print( "合并后的数据集维度:",x.data.size(), y.data.size() )
 
#将Tensor放入Variable中
x,y = Variable(x), Variable(y)
 
#载入模型和参数
def restore_net():
 net = torch.load('net.pkl')
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
#仅载入模型参数,需要先创建网络模型
def restore_param():
 net = torch.nn.Sequential(
  torch.nn.Linear(2,10),#指定输入层和隐层结点,获得隐层线性输出
  torch.nn.ReLU(),#隐层非线性化
  torch.nn.Linear(10,2)#指定隐层和输出层结点,获得输出层线性输出
 )
 
 net.load_state_dict( torch.load('net_param.pkl') )
 #获得载入模型的预测输出
 pred = net(x)
 # 获得当前softmax层最大概率对应的索引值
 pred = torch.max(F.softmax(pred), 1)[1]
 # 将二维压缩为一维
 pred_y = pred.data.numpy().squeeze()
 label_y = y.data.numpy()
 accuracy = sum(pred_y == label_y) / y.size()
 print("准确率为:",accuracy)
 plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, cmap='RdYlGn')
 plt.show()
 
if __name__ =='__main__':
 # restore_net()
 restore_param()

以上这篇Pytorch实现神经网络的分类方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python入门篇之编程习惯与特点
Oct 17 Python
python开发之list操作实例分析
Feb 22 Python
TensorFLow用Saver保存和恢复变量
Mar 10 Python
python用列表生成式写嵌套循环的方法
Nov 08 Python
Python和Go语言的区别总结
Feb 20 Python
python把1变成01的步骤总结
Feb 27 Python
python实现AES加密与解密
Mar 28 Python
Python中*args和**kwargs的区别详解
Sep 17 Python
wxpython+pymysql实现用户登陆功能
Nov 19 Python
Django模板获取field的verbose_name实例
May 19 Python
OpenCV4.1.0+VS2017环境配置的方法步骤
Jul 09 Python
Python自动发送和收取邮件的方法
Aug 12 Python
python 爬取古诗文存入mysql数据库的方法
Jan 08 #Python
基于python3抓取pinpoint应用信息入库
Jan 08 #Python
Python PyInstaller安装和使用教程详解
Jan 08 #Python
关于Pytorch的MLP模块实现方式
Jan 07 #Python
PyTorch 普通卷积和空洞卷积实例
Jan 07 #Python
Pytorch中膨胀卷积的用法详解
Jan 07 #Python
Python urlopen()和urlretrieve()用法解析
Jan 07 #Python
You might like
我常用的几个类
2006/10/09 PHP
php面向对象全攻略 (三)特殊的引用“$this”的使用
2009/09/30 PHP
php读取txt文件组成SQL并插入数据库的代码(原创自Zjmainstay)
2012/07/31 PHP
PHP采集类Snoopy抓取图片实例
2014/06/19 PHP
PHP递归遍历多维数组实现无限分类的方法
2016/05/06 PHP
Laravel中错误与异常处理的用法示例
2018/09/16 PHP
jquery HotKeys轻松搞定键盘事件代码
2008/08/30 Javascript
JQuery Study Notes 学习笔记(一)
2010/08/04 Javascript
原生Js实现按的数据源均分时间点幻灯片效果(已封装)
2010/12/28 Javascript
jQuery+JSON+jPlayer实现QQ空间音乐查询功能示例
2013/06/17 Javascript
js 判断计算字符串长度/判断空的简单方法
2013/08/05 Javascript
JavaScript判断变量是对象还是数组的方法
2014/08/28 Javascript
如何利用JQuery实现从底部回到顶部的功能
2016/12/27 Javascript
javaScript基础详解
2017/01/19 Javascript
JavaScript实现经纬度转换成地址功能
2017/03/28 Javascript
详解vue中computed 和 watch的异同
2017/06/30 Javascript
基于Vue.js与WordPress Rest API构建单页应用详解
2019/09/16 Javascript
测试、预发布后用python检测网页是否有日常链接
2014/06/03 Python
Python和GO语言实现的消息摘要算法示例
2015/03/10 Python
Python语言生成水仙花数代码示例
2017/12/18 Python
python 接口测试response返回数据对比的方法
2018/02/11 Python
python儿童学游戏编程知识点总结
2019/06/03 Python
Django实现auth模块下的登录注册与注销功能
2019/10/10 Python
Python中有几个关键字
2020/06/04 Python
python调用有道智云API实现文件批量翻译
2020/10/10 Python
python regex库实例用法总结
2021/01/03 Python
基于zepto的插件之移动端无缝向上滚动并上下触摸滑动实例代码
2016/12/20 HTML / CSS
美国隐形眼镜网上商店:Lens.com
2019/09/03 全球购物
安德玛菲律宾官网:Under Armour菲律宾
2020/07/28 全球购物
英语专业个人求职自荐信
2013/09/21 职场文书
环保建议书300字
2014/05/14 职场文书
教师党员先进性教育自我剖析材料思想汇报
2014/09/24 职场文书
夫妻分居协议书范文
2014/11/26 职场文书
课外活动实习计划
2015/01/19 职场文书
2019年朋友圈经典励志语录50条
2019/07/05 职场文书
浅谈JS的二进制家族
2021/05/09 Javascript