Pytorch 神经网络—自定义数据集上实现教程


Posted in Python onJanuary 07, 2020

第一步、导入需要的包

import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10   # epoch的最大值

第二步、构建神经网络

Pytorch 神经网络—自定义数据集上实现教程

设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:

class Neuralnetwork(nn.Module):
  def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
    super(Neuralnetwork, self).__init__()
    self.layer1 = nn.Linear(in_dim, n_hidden_1)
    self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
    self.layer3 = nn.Linear(n_hidden_2, out_dim)
 
  def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    return x
 
model = Neuralnetwork(1*3, 4, 4, 1)
 
print(model) # net architecture
Neuralnetwork(
 (layer1): Linear(in_features=3, out_features=4, bias=True)
 (layer2): Linear(in_features=4, out_features=4, bias=True)
 (layer3): Linear(in_features=4, out_features=1, bias=True)
)

​​ 第三步、读取数据

自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签

Pytorch 神经网络—自定义数据集上实现教程

class SBPEstimateDataset(Dataset):
 
  def __init__(self, ext='demo'):
  
    data = sio.loadmat(ext+'_SBPFea.mat')
    self.fea = data['fea']
    self.sbp = data['sbp']
    
  def __len__(self):
    
    return len(self.sbp)
 
  def __getitem__(self, idx):
 
    fea = self.fea[idx]
    sbp = self.sbp[idx]
    """Convert ndarrays to Tensors."""
    return {'fea': torch.from_numpy(fea).float(),
        'sbp': torch.from_numpy(sbp).float()
        }
    
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
             shuffle=True, num_workers=int(8))

整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的

第四步、模型训练

# 优化器,Adam 
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004) 
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 
criterion = nn.MSELoss() # loss function 
 
if torch.cuda.is_available(): # 有GPU,则用GPU计算
   model.cuda() 
   criterion.cuda() 
 
for epoch in range(niter): 
   losses = [] 
   ERROR_Train = [] 
   model.train() 
   for i, data in enumerate(train_loader, 0): 
     model.zero_grad()# 首先提取清零 
     real_cpu, label_cpu = data['fea'], data['sbp'] 
 
     if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行 
       real_cpu = real_cpu.cuda() 
       label_cpu = label_cpu.cuda() 
 
 
       input=real_cpu 
       label=label_cpu 
 
       inputv = Variable(input) 
       labelv = Variable(label) 
 
       output = model(inputv) 
       err = criterion(output, labelv) 
       err.backward() 
       optimizer.step() 
 
       losses.append(err.data[0]) 
 
       error = output.data-label+ 1e-12 
       ERROR_Train.extend(error) 
 
   MAE = np.average(np.abs(np.array(ERROR_Train))) 
   ME = np.average(np.array(ERROR_Train)) 
   STD = np.std(np.array(ERROR_Train)) 
 
   print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % ( 
   epoch, niter, np.average(losses), MAE, ME, STD))
​​
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332​

以上这篇Pytorch 神经网络—自定义数据集上实现教程就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python urlopen 使用小示例
Sep 06 Python
Python类方法__init__和__del__构造、析构过程分析
Mar 06 Python
介绍Python中几个常用的类方法
Apr 08 Python
Python中几种操作字符串的方法的介绍
Apr 09 Python
Python实现二维数组按照某行或列排序的方法【numpy lexsort】
Sep 22 Python
Python简单实现控制电脑的方法
Jan 22 Python
详解Python Matplot中文显示完美解决方案
Mar 07 Python
Django+uni-app实现数据通信中的请求跨域的示例代码
Oct 12 Python
python 表格打印代码实例解析
Oct 12 Python
django xadmin action兼容自定义model权限教程
Mar 30 Python
jupyter 导入csv文件方式
Apr 21 Python
Python爬虫基础初探selenium
May 31 Python
浅谈Python访问MySQL的正确姿势
Jan 07 #Python
pytorch自定义二值化网络层方式
Jan 07 #Python
Pytorch: 自定义网络层实例
Jan 07 #Python
Python StringIO如何在内存中读写str
Jan 07 #Python
Python内置数据类型list各方法的性能测试过程解析
Jan 07 #Python
python模拟实现斗地主发牌
Jan 07 #Python
python全局变量引用与修改过程解析
Jan 07 #Python
You might like
PHP 文本文章分页代码 按标记或长度(不涉及数据库)
2012/06/07 PHP
php导出word格式数据的代码实例
2013/11/25 PHP
phpcms手机内容页面添加上一篇和下一篇
2015/06/05 PHP
php ci 获取表单中多个同名input元素值的代码
2016/03/25 PHP
PHP删除字符串中非字母数字字符方法总结
2019/01/20 PHP
PHP实现提取多维数组指定一列的方法总结
2019/12/04 PHP
jquery复选框CHECKBOX全选、反选
2008/08/30 Javascript
JS DOM 操作实现代码
2010/08/01 Javascript
jquery html动态生成select标签出问题的解决方法
2013/11/20 Javascript
js获取url中指定参数值的示例代码
2013/12/14 Javascript
二叉树先序遍历的非递归算法具体实现
2014/01/09 Javascript
判断字符串的长度(优化版)中文占两个字符
2014/10/30 Javascript
javascript实现时间格式输出FormatDate函数
2015/01/13 Javascript
JS实现统计复选框选中个数并提示确定与取消的方法
2015/07/01 Javascript
javascript实现方法调用与方法触发小结
2016/03/26 Javascript
微信小程序 wx:key详细介绍
2016/10/28 Javascript
ES6基础之数组和对象的拓展实例详解
2019/08/22 Javascript
JS实现打字游戏
2019/12/17 Javascript
js屏蔽F12审查元素,禁止修改页面代码等实现代码
2020/10/02 Javascript
vue从后台渲染文章列表以及根据id跳转文章详情详解
2020/12/14 Vue.js
[36:20]完美世界DOTA2联赛PWL S3 access vs Rebirth 第一场 12.17
2020/12/18 DOTA
python实现的解析crontab配置文件代码
2014/06/30 Python
Python遍历目录并批量更换文件名和目录名的方法
2016/09/19 Python
浅谈使用Python内置函数getattr实现分发模式
2018/01/22 Python
PyQt5主窗口动态加载Widget实例代码
2018/02/07 Python
使用Python轻松完成垃圾分类(基于图像识别)
2019/07/09 Python
简单了解python PEP的一些知识
2019/07/13 Python
python 解压、复制、删除 文件的实例代码
2020/02/26 Python
css3实现wifi信号逐渐增强效果实例
2017/08/09 HTML / CSS
德国低价购买灯具和家具网站:Style-home.de
2016/11/25 全球购物
Bergfreunde丹麦:登山装备网上零售商
2017/02/26 全球购物
美国最大的在线生存商店:Survival Frog
2020/12/13 全球购物
化工专业应届生求职信
2013/11/08 职场文书
致百米运动员广播稿5篇
2014/10/13 职场文书
迟到检讨书
2015/01/26 职场文书
win11电脑关机鼠标灯还亮怎么解决? win11关机后鼠标灯还亮解决方法
2023/01/09 数码科技