pytorch训练神经网络爆内存的解决方案


Posted in Python onMay 22, 2021

训练的时候内存一直在增加,最后内存爆满,被迫中断。

pytorch训练神经网络爆内存的解决方案

后来换了一个电脑发现还是这样,考虑是代码的问题。

检查才发现我的代码两次存了loss,只有一个地方写的是loss.item()。问题就在loss,因为loss是variable类型。

要写成loss_train = loss_train + loss.item(),不能直接写loss_train = loss_train + loss。否则就会发现随着epoch的增加,占的内存也在一点一点增加。

算是一个小坑吧,希望大家还是要仔细。

补充:pytorch神经网络解决回归问题(非常易懂)

对于pytorch的深度学习框架

在建立人工神经网络时整体的步骤主要有以下四步:

1、载入原始数据

2、构建具体神经网络

3、进行数据的训练

4、数据测试和验证

pytorch神经网络的数据载入,以MINIST书写字体的原始数据为例:

import torch
import matplotlib.pyplot as  plt
def plot_curve(data):
    fig=plt.figure()
    plt.plot(range(len(data)),data,color="blue")
    plt.legend(["value"],loc="upper right")
    plt.xlabel("step")
    plt.ylabel("value")
    plt.show()
 
def plot_image(img,label,name):
    fig=plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")
        plt.title("{}:{}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
def one_hot(label,depth=10):
    out=torch.zeros(label.size(0),depth)
    idx=torch.LongTensor(label).view(-1,1)
    out.scatter_(dim=1,index=idx,value=1)
    return out
 
batch_size=512
import torch
from torch import nn                         #完成神经网络的构建包
from torch.nn import functional as F         #包含常用的函数包
from torch import optim                      #优化工具包
import torchvision                           #视觉工具包
import  matplotlib.pyplot as plt
from utils import plot_curve,plot_image,one_hot
#step1 load dataset   加载数据包
train_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(
    torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))
         ])),
    batch_size=batch_size,shuffle=False)
x,y=next(iter(train_loader))
print(x.shape,y.shape)
plot_image(x,y,"image")
print(x)
print(y)

以构建一个简单的回归问题的神经网络为例,

其具体的实现代码如下所示:

import torch
import torch.nn.functional as F  # 激励函数都在这
 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)
 
class Net(torch.nn.Module):  # 继承 torch 的 Module(固定)
    def __init__(self, n_feature, n_hidden, n_output):  # 定义层的信息,n_feature多少个输入, n_hidden每层神经元, n_output多少个输出
        super(Net, self).__init__()  # 继承 __init__ 功能(固定)
        # 定义每层用什么样的形式
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,线性输出
        self.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层线性输出
 
    def forward(self, x):  # x是输入信息就是data,同时也是 Module 中的 forward 功能,定义神经网络前向传递的过程,把__init__中的层信息一个一个的组合起来
        # 正向传播输入值, 神经网络分析出输出值
        x = F.relu(self.hidden(x))  # 定义激励函数(隐藏层的线性值)
        x = self.predict(x)  # 输出层,输出值
        return x 
 
net = Net(n_feature=1, n_hidden=10, n_output=1) 
print(net)  # net 的结构
"""
Net (
  (hidden): Linear (1 -> 10)
  (predict): Linear (10 -> 1)
)
"""
# optimizer 是训练的工具
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 传入 net 的所有参数, 学习率
loss_func = torch.nn.MSELoss()  # 预测值和真实值的误差计算公式 (均方差)
 
for t in range(100):  # 训练的步数100步
    prediction = net(x)  # 喂给 net 训练数据 x, 每迭代一步,输出预测值
 
    loss = loss_func(prediction, y)  # 计算两者的误差
 
    # 优化步骤:
    optimizer.zero_grad()  # 清空上一步的残余更新参数值
    loss.backward()  # 误差反向传播, 计算参数更新值
    optimizer.step()  # 将参数更新值施加到 net 的 parameters 上
 
import matplotlib.pyplot as plt 
plt.ion()  # 实时画图something about plotting 
for t in range(200):
    prediction = net(x)  # input x and predict based on x 
    loss = loss_func(prediction, y)  # must be (1. nn output, 2. target) 
    optimizer.zero_grad()  # clear gradients for next train
    loss.backward()  # backpropagation, compute gradients
    optimizer.step()  # apply gradients
 
    if t % 5 == 0:  # 每五步绘一次图
        # plot and show learning process
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
        plt.pause(0.1)
 
plt.ioff()
plt.show()

pytorch训练神经网络爆内存的解决方案

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用在线API查询IP对应的地理位置信息实例
Jun 01 Python
Python中死锁的形成示例及死锁情况的防止
Jun 14 Python
利用numpy+matplotlib绘图的基本操作教程
May 03 Python
磁盘垃圾文件清理器python代码实现
Aug 24 Python
Python设计模式之适配器模式原理与用法详解
Jan 15 Python
Python爬虫实现“盗取”微信好友信息的方法分析
Sep 16 Python
git查看、创建、删除、本地、远程分支方法详解
Feb 18 Python
Python ini文件常用操作方法解析
Apr 26 Python
python向xls写入数据(包括合并,边框,对齐,列宽)
Feb 02 Python
在pycharm中无法import所安装的库解决方案
May 31 Python
Python 语言实现六大查找算法
Jun 30 Python
PO模式在selenium自动化测试框架的优势
Mar 20 Python
粗暴解决CUDA out of memory的问题
May 22 #Python
pytorch中的model.eval()和BN层的使用
May 22 #Python
解决Pytorch中关于model.eval的问题
Pytorch 中net.train 和 net.eval的使用说明
May 22 #Python
对PyTorch中inplace字段的全面理解
May 22 #Python
pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
May 22 #Python
用python实现监控视频人数统计
You might like
php学习笔记之 函数声明
2011/06/09 PHP
php内核解析:PHP中的哈希表
2014/01/30 PHP
php常用数组array函数实例总结【赋值,拆分,合并,计算,添加,删除,查询,判断,排序】
2016/12/07 PHP
Javascript中的var_dump函数实现代码
2009/09/07 Javascript
jquery ajax 局部无刷新更新数据的实现案例
2014/02/08 Javascript
利用函数的惰性载入提高javascript代码执行效率
2014/05/05 Javascript
基于jQuery实现表单提交验证
2014/11/24 Javascript
JavaScript中数组的合并以及排序实现示例
2015/10/24 Javascript
概述javascript在Google IE中的调试技巧
2016/11/24 Javascript
bootstrap confirmation按钮提示组件使用详解
2017/08/22 Javascript
vue实现五子棋游戏
2020/05/28 Javascript
[01:20]2018DOTA2亚洲邀请赛总决赛战队LGD晋级之路
2018/04/07 DOTA
Python实现快速多线程ping的方法
2015/07/15 Python
Python机器学习之决策树算法实例详解
2017/12/06 Python
将python代码和注释分离的方法
2018/04/21 Python
python分批定量读取文件内容,输出到不同文件中的方法
2018/12/08 Python
在Python中表示一个对象的方法
2019/06/25 Python
Python实现的远程文件自动打包并下载功能示例
2019/07/12 Python
python中下标和切片的使用方法解析
2019/08/27 Python
PYQT5 vscode联合操作qtdesigner的方法
2020/03/24 Python
Jupyter Notebook输出矢量图实例
2020/04/14 Python
python装饰器代码深入讲解
2021/03/01 Python
什么是Smart Navigation?
2016/07/03 面试题
化学学院毕业生自荐信范文
2013/12/17 职场文书
大学生职业生涯规划书模板
2014/01/18 职场文书
七一党建活动方案
2014/01/28 职场文书
商场端午节活动方案
2014/01/29 职场文书
物业总经理岗位职责
2014/02/28 职场文书
领导干部贪图享乐整改措施
2014/09/21 职场文书
高中生旷课检讨书
2014/10/08 职场文书
教师师德师风整改措施
2014/10/24 职场文书
师德师风个人整改措施
2014/10/27 职场文书
红色经典观后感
2015/06/18 职场文书
城南旧事读书笔记
2015/06/29 职场文书
如何做好员工培训计划?
2019/07/09 职场文书
导游词之张家口
2019/12/13 职场文书