Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中不同进制互相转换(二进制、八进制、十进制和十六进制)
Apr 05 Python
python导出chrome书签到markdown文件的实例代码
Dec 27 Python
Python文本统计功能之西游记用字统计操作示例
May 07 Python
Appium+python自动化之连接模拟器并启动淘宝APP(超详解)
Jun 17 Python
python 求某条线上特定x值或y值的点坐标方法
Jul 09 Python
Python @property使用方法解析
Sep 17 Python
PyTorch预训练的实现
Sep 18 Python
Django中modelform组件实例用法总结
Feb 10 Python
pytorch中图像的数据格式实例
Feb 11 Python
Python+OpenCV实现图像的全景拼接
Mar 05 Python
keras绘制acc和loss曲线图实例
Jun 15 Python
python实现批量移动文件
Apr 05 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
用PHP制作静态网站的模板框架(三)
2006/10/09 PHP
一个连接两个不同MYSQL数据库的PHP程序
2006/10/09 PHP
php explode函数实例代码
2012/02/27 PHP
PHP获取MSN好友列表类的实现代码
2013/06/23 PHP
php绘图之加载外部图片的方法
2015/01/24 PHP
PHP实现的最大正向匹配算法示例
2017/12/19 PHP
Gambit vs ForZe BO3 第一场 2.13
2021/03/10 DOTA
Javascript Jquery 遍历Json的实现代码
2010/03/31 Javascript
js局部刷新页面时间具体实现
2013/07/04 Javascript
IE下写xml文件的两种方式(fso/saveAs)
2013/08/05 Javascript
js调试系列 控制台命令行API使用方法
2014/06/18 Javascript
JQuery实现动态添加删除评论的方法
2015/05/18 Javascript
基于JavaScript实现移动端TAB触屏切换效果
2015/10/20 Javascript
js当前页面登录注册框,固定div,底层阴影的实例代码
2016/10/04 Javascript
Ionic学习日记实现验证码倒计时
2018/02/08 Javascript
解决vue+element 键盘回车事件导致页面刷新的问题
2018/08/25 Javascript
clipboard在vue中的使用的方法示例
2018/10/19 Javascript
基于vue.js实现分页查询功能
2018/12/29 Javascript
vue自定义指令之面板拖拽的实现
2019/04/14 Javascript
使用vuex存储用户信息到localStorage的实例
2019/11/11 Javascript
Vue+ElementUI使用vue-pdf实现预览功能
2019/11/26 Javascript
webpack proxy 使用(代理的使用)
2020/01/10 Javascript
Python中的True,False条件判断实例分析
2015/01/12 Python
Python中使用装饰器时需要注意的一些问题
2015/05/11 Python
python中字符串前面加r的作用
2015/06/04 Python
安装ElasticSearch搜索工具并配置Python驱动的方法
2015/12/22 Python
python获取时间及时间格式转换问题实例代码详解
2018/12/06 Python
浅谈pytorch grad_fn以及权重梯度不更新的问题
2019/08/20 Python
Python之关于类变量的两种赋值区别详解
2020/03/12 Python
伊芙丽官方旗舰店:中国淑女一线品牌
2017/12/01 全球购物
大学毕业生通用自荐信范文
2013/10/31 职场文书
美术课外活动总结
2014/07/08 职场文书
投标授权委托书范文
2014/08/02 职场文书
使用goaccess分析nginx日志的详细方法
2021/07/09 Servers
Win11筛选键导致键盘失灵怎么解决? Win11关闭筛选键的技巧
2022/04/08 数码科技
JAVA springCloud项目搭建流程
2022/05/11 Java/Android