pytorch 6 batch_train 批训练操作


Posted in Python onMay 28, 2021

看代码吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 设置不随机打乱数据 random shuffle for training
    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

实例化一个DataLoader对象

第三步:

for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现的计算马氏距离算法示例
Apr 03 Python
对numpy中的数组条件筛选功能详解
Jul 02 Python
对python列表里的字典元素去重方法详解
Jan 21 Python
Python+OpenCV图片局部区域像素值处理改进版详解
Jan 23 Python
Python实现简单查找最长子串功能示例
Feb 26 Python
python中append实例用法总结
Jul 30 Python
Python使用selenium + headless chrome获取网页内容的方法示例
Oct 16 Python
numpy:找到指定元素的索引示例
Nov 26 Python
Python编译为二进制so可执行文件实例
Dec 23 Python
Django ForeignKey与数据库的FOREIGN KEY约束详解
May 20 Python
C++和python实现阿姆斯特朗数字查找实例代码
Dec 07 Python
利用python进行数据加载
Jun 20 Python
pytorch 如何使用batch训练lstm网络
May 28 #Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
You might like
php smarty 二级分类代码和模版循环例子
2011/06/16 PHP
php使用curl检测网页是否被百度收录的示例分享
2014/01/31 PHP
PHP数据库万能引擎类adodb配置使用以及实例集锦
2014/06/12 PHP
跨浏览器PHP下载文件名中的中文乱码问题解决方法
2015/03/05 PHP
如何将JS的变量值传递给ASP变量
2012/12/10 Javascript
jquery实现的可隐藏重现的靠边悬浮层实例代码
2013/05/27 Javascript
window.location 对象所包含的属性
2014/10/10 Javascript
jQuery在页面加载时动态修改图片尺寸的方法
2015/03/20 Javascript
jquery实现图片水平滚动效果代码分享
2015/08/26 Javascript
js图片轮播特效代码分享
2015/09/07 Javascript
JS+CSS实现简易实用的滑动门菜单效果
2015/09/18 Javascript
NodeJS创建基础应用并应用模板引擎
2016/04/12 NodeJs
JS组件Bootstrap导航条使用方法详解
2016/04/29 Javascript
详解Vue路由开启keep-alive时的注意点
2017/06/20 Javascript
原生JS实现移动端web轮播图详解(结合Tween算法造轮子)
2017/09/10 Javascript
vue脚手架及vue-router基本使用
2018/04/09 Javascript
使用javascript做时间倒数读秒功能的实例
2019/01/23 Javascript
javascript获取select值的方法完整实例
2019/06/20 Javascript
vue实现简单的日历效果
2020/09/24 Javascript
vue实现文字加密功能
2019/09/27 Javascript
JS回调函数简单易懂的入门实例分析
2019/09/29 Javascript
纯js+css实现仿移动端淘宝网站的弹出详情框功能
2019/12/29 Javascript
JS中FormData类实现文件上传
2020/03/27 Javascript
javaScript实现一个队列的方法
2020/07/14 Javascript
Python tkinter模块中类继承的三种方式分析
2017/08/08 Python
浅谈Python编程中3个常用的数据结构和算法
2019/04/30 Python
解决django中form表单设置action后无法回到原页面的问题
2020/03/13 Python
Python类中的装饰器在当前类中的声明与调用详解
2020/04/15 Python
使用SVG实现提示框功能的示例代码
2020/06/05 HTML / CSS
北承题目(C++)
2012/05/16 面试题
终端业务员岗位职责
2013/11/27 职场文书
纪念九一八事变演讲稿:勿忘国耻
2014/09/14 职场文书
实训报告范文大全
2014/11/04 职场文书
员工工作能力评语
2014/12/31 职场文书
人事行政部各岗位职责说明书!
2019/07/15 职场文书
设置IIS Express并发数
2022/07/07 Servers