pytorch训练神经网络爆内存的解决方案


Posted in Python onMay 22, 2021

训练的时候内存一直在增加,最后内存爆满,被迫中断。

pytorch训练神经网络爆内存的解决方案

后来换了一个电脑发现还是这样,考虑是代码的问题。

检查才发现我的代码两次存了loss,只有一个地方写的是loss.item()。问题就在loss,因为loss是variable类型。

要写成loss_train = loss_train + loss.item(),不能直接写loss_train = loss_train + loss。否则就会发现随着epoch的增加,占的内存也在一点一点增加。

算是一个小坑吧,希望大家还是要仔细。

补充:pytorch神经网络解决回归问题(非常易懂)

对于pytorch的深度学习框架

在建立人工神经网络时整体的步骤主要有以下四步:

1、载入原始数据

2、构建具体神经网络

3、进行数据的训练

4、数据测试和验证

pytorch神经网络的数据载入,以MINIST书写字体的原始数据为例:

import torch
import matplotlib.pyplot as  plt
def plot_curve(data):
    fig=plt.figure()
    plt.plot(range(len(data)),data,color="blue")
    plt.legend(["value"],loc="upper right")
    plt.xlabel("step")
    plt.ylabel("value")
    plt.show()
 
def plot_image(img,label,name):
    fig=plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
def one_hot(label,depth=10):
    out=torch.zeros(label.size(0),depth)
    idx=torch.LongTensor(label).view(-1,1)
    out.scatter_(dim=1,index=idx,value=1)
    return out
 
batch_size=512
import torch
from torch import nn                         #完成神经网络的构建包
from torch.nn import functional as F         #包含常用的函数包
from torch import optim                      #优化工具包
import torchvision                           #视觉工具包
import  matplotlib.pyplot as plt
from utils import plot_curve,plot_image,one_hot
#step1 load dataset   加载数据包
train_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=False)
x,y=next(iter(train_loader))
print(x.shape,y.shape)
plot_image(x,y,"image")
print(x)
print(y)

以构建一个简单的回归问题的神经网络为例,

其具体的实现代码如下所示:

import torch
import torch.nn.functional as F  # 激励函数都在这
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
 
class Net(torch.nn.Module):  # 继承 torch 的 Module(固定)
    def __init__(self, n_feature, n_hidden, n_output):  # 定义层的信息,n_feature多少个输入, n_hidden每层神经元, n_output多少个输出
        super(Net, self).__init__()  # 继承 __init__ 功能(固定)
        # 定义每层用什么样的形式
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,线性输出
        self.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层线性输出
 
    def forward(self, x):  # x是输入信息就是data,同时也是 Module 中的 forward 功能,定义神经网络前向传递的过程,把__init__中的层信息一个一个的组合起来
        # 正向传播输入值, 神经网络分析出输出值
        x = F.relu(self.hidden(x))  # 定义激励函数(隐藏层的线性值)
        x = self.predict(x)  # 输出层,输出值
        return x 
 
net = Net(n_feature=1, n_hidden=10, n_output=1) 
print(net)  # net 的结构
"""
Net (
  (hidden): Linear (1 -> 10)
  (predict): Linear (10 -> 1)
)
"""
# optimizer 是训练的工具
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss()  # 预测值和真实值的误差计算公式 (均方差)
 
for t in range(100):  # 训练的步数100步
    prediction = net(x)  # 喂给 net 训练数据 x, 每迭代一步,输出预测值
 
    loss = loss_func(prediction, y)  # 计算两者的误差
 
    # 优化步骤:
    optimizer.zero_grad()  # 清空上一步的残余更新参数值
    loss.backward()  # 误差反向传播, 计算参数更新值
    optimizer.step()  # 将参数更新值施加到 net 的 parameters 上
 
import matplotlib.pyplot as plt 
plt.ion()  # 实时画图something about plotting 
for t in range(200):
    prediction = net(x)  # input x and predict based on x 
    loss = loss_func(prediction, y)  # must be (1. nn output, 2. target) 
    optimizer.zero_grad()  # clear gradients for next train
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients
 
    if t % 5 == 0:  # 每五步绘一次图
        # plot and show learning process
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)
 
plt.ioff()
plt.show()

pytorch训练神经网络爆内存的解决方案

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取指定路径下所有指定后缀文件的方法
May 26 Python
在Lighttpd服务器中运行Django应用的方法
Jul 22 Python
python 通过字符串调用对象属性或方法的实例讲解
Apr 21 Python
python处理DICOM并计算三维模型体积
Feb 26 Python
Python中的random.uniform()函数教程与实例解析
Mar 02 Python
python中matplotlib条件背景颜色的实现
Sep 02 Python
python 的 openpyxl模块 读取 Excel文件的方法
Sep 09 Python
利用jupyter网页版本进行python函数查询方式
Apr 14 Python
导致python中import错误的原因是什么
Jul 01 Python
Python3爬虫mitmproxy的安装步骤
Jul 29 Python
Django启动时找不到mysqlclient问题解决方案
Nov 11 Python
Python3中的tuple函数知识点讲解
Jan 03 Python
粗暴解决CUDA out of memory的问题
May 22 #Python
pytorch中的model.eval()和BN层的使用
May 22 #Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
You might like
PHP4与PHP5的时间格式问题
2008/02/17 PHP
PHP 超链接 抓取实现代码
2009/06/29 PHP
php错误、异常处理机制(补充)
2012/05/07 PHP
PHP操作MongoDB实现增删改查功能【附php7操作MongoDB方法】
2018/04/24 PHP
ThinkPHP3.2框架自定义配置和加载用法示例
2018/06/14 PHP
阿里云的WindowsServer2016上部署php+apache
2018/07/17 PHP
JS案例分享之金额小写转大写
2014/05/15 Javascript
Javascript 学习笔记之 对象篇(二) : 原型对象
2014/06/24 Javascript
javascript实现获取浏览器版本、操作系统类型
2015/01/29 Javascript
个人总结的一些JavaScript技巧、实用函数、简洁方法、编程细节
2015/06/10 Javascript
浅谈JavaScript中小数和大整数的精度丢失
2016/05/31 Javascript
bootstrap table之通用方法( 时间控件,导出,动态下拉框, 表单验证 ,选中与获取信息)代码分享
2017/01/24 Javascript
Extjs表单输入框异步校验的插件实现方法
2017/03/20 Javascript
bootstrap-table.js扩展分页工具栏(增加跳转到xx页)功能
2017/12/28 Javascript
vue2.0$nextTick监听数据渲染完成之后的回调函数方法
2018/09/11 Javascript
Python中的TCP socket写法示例
2018/05/11 Python
python 自动去除空行的实例
2018/07/24 Python
python简易实现任意位数的水仙花实例
2018/11/13 Python
从列表或字典创建Pandas的DataFrame对象的方法
2019/07/06 Python
python增加图像对比度的方法
2019/07/12 Python
python操作excel让工作自动化
2019/08/09 Python
利用Tensorflow构建和训练自己的CNN来做简单的验证码识别方式
2020/01/20 Python
python通过文本在一个图中画多条线的实例
2020/02/21 Python
python中线程和进程有何区别
2020/06/17 Python
python基于socket模拟实现ssh远程执行命令
2020/12/05 Python
阿根廷旅游网站:almundo阿根廷
2018/02/12 全球购物
毕业生求职简历中的自我评价
2013/10/18 职场文书
自荐信写法介绍
2014/01/25 职场文书
2014年寒假社会实践活动心得体会
2014/04/07 职场文书
敬老模范事迹
2014/05/21 职场文书
党支部班子“四风”问题自我剖析材料
2014/09/28 职场文书
写给同学的新学期寄语
2015/02/27 职场文书
求职意向书范本
2015/05/11 职场文书
2015年度招聘工作总结
2015/05/28 职场文书
公司员工离职感言
2015/08/03 职场文书
一波干货,会议主持词开场白范文
2019/05/06 职场文书