pytorch 6 batch_train 批训练操作


Posted in Python onMay 28, 2021

看代码吧~

import torch
import torch.utils.data as Data
torch.manual_seed(1)    # reproducible
# BATCH_SIZE = 5  
BATCH_SIZE = 8      # 每次使用8个数据同时传入网路
x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # 设置不随机打乱数据 random shuffle for training
    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data
)
def show_batch():
    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())
if __name__ == '__main__':
    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

实例化一个DataLoader对象

第三步:

for epoch in range(epochs):
        for step, (batch_x, batch_y) in enumerate(train_loader):
            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现简单的获取图片爬虫功能示例
Jul 12 Python
详解 Python 与文件对象共事的实例
Sep 11 Python
Python爬虫之UserAgent的使用实例
Feb 21 Python
Python、 Pycharm、Django安装详细教程(图文)
Apr 12 Python
详解如何管理多个Python版本和虚拟环境
May 10 Python
python中pip的使用和修改下载源的方法
Jul 08 Python
python内存管理机制原理详解
Aug 12 Python
使用turtle绘制五角星、分形树
Oct 06 Python
Python面向对象之继承原理与用法案例分析
Dec 31 Python
Python Sphinx使用实例及问题解决
Jan 17 Python
python实现ip地址的包含关系判断
Feb 07 Python
django ORM之values和annotate使用详解
May 19 Python
pytorch 如何使用batch训练lstm网络
May 28 #Python
使用Pytorch训练two-head网络的操作
May 28 #Python
使用Python的开发框架Brownie部署以太坊智能合约
使用Pytorch实现two-head(多输出)模型的操作
8g内存用python读取10文件_面试题-python 如何读取一个大于 10G 的txt文件?
用python画城市轮播地图
用Python实现一个打字速度测试工具来测试你的手速
You might like
漫荒推荐:画风超赞的国风漫画推荐 超长假期不无聊
2020/03/08 国漫
如何冲泡挂耳包咖啡?技巧是什么
2021/03/04 冲泡冲煮
PHP学习之输出字符串(echo,print,printf,print_r和var_dump)
2011/04/17 PHP
九个你必须知道而且又很好用的php函数和特点
2013/08/08 PHP
php5.3后静态绑定用法详解
2016/11/11 PHP
Laravel 5.4前后台分离,通过不同的二级域名访问方法
2019/10/13 PHP
JavaScript继承方式实例
2010/10/29 Javascript
jquer之ajaxQueue简单实现代码
2011/09/15 Javascript
JQuery操作表格(隔行着色,高亮显示,筛选数据)
2012/02/23 Javascript
js判断FCKeditor内容是否为空的两种形式
2013/05/14 Javascript
js出生日期 年月日级联菜单示例代码
2014/01/10 Javascript
jquery插件ajaxupload实现文件上传操作
2015/12/09 Javascript
AngularJS基础 ng-repeat 指令简单示例
2016/08/03 Javascript
js 倒计时(高效率服务器时间同步)
2017/09/12 Javascript
Vue.js最佳实践(五招助你成为vuejs大师)
2018/05/04 Javascript
使用JavaScript生成罗马字符的实例代码
2018/06/08 Javascript
Mint UI组件库CheckList使用及踩坑总结
2018/12/20 Javascript
Vue+Element实现表格编辑、删除、以及新增行的最优方法
2019/05/28 Javascript
Django应用程序中如何发送电子邮件详解
2017/02/04 Python
Python将多个excel文件合并为一个文件
2018/01/03 Python
python实现批量注册网站用户的示例
2019/02/22 Python
解决windows下python3使用multiprocessing.Pool出现的问题
2020/04/08 Python
Python使用Matlab命令过程解析
2020/06/04 Python
Python pip使用超时问题解决方案
2020/08/03 Python
Django自带的用户验证系统实现
2020/12/18 Python
css3实现可滑动跳转的分页插件示例
2014/05/08 HTML / CSS
CSS3感应鼠标的背景闪烁和图片缩放动画效果
2014/05/14 HTML / CSS
Stuart Weitzman美国官网:美国奢华鞋履品牌
2016/08/18 全球购物
英国家庭珠宝商:T. H. Baker
2018/02/08 全球购物
益模软件Java笔试题
2012/03/27 面试题
公司薪酬管理制度
2014/01/31 职场文书
金融专业毕业生自荐信
2014/06/26 职场文书
要账委托书范本
2014/09/15 职场文书
农村党支部书记党群众路线四风问题整改措施
2014/09/26 职场文书
党员干部反四风民主生活会对照检查材料思想汇报
2014/10/12 职场文书
违反学校规则制度检讨书
2015/01/01 职场文书