详解Pytorch 使用Pytorch拟合多项式(多项式回归)


Posted in Python onMay 24, 2018

使用Pytorch来编写神经网络具有很多优势,比起Tensorflow,我认为Pytorch更加简单,结构更加清晰。

希望通过实战几个Pytorch的例子,让大家熟悉Pytorch的使用方法,包括数据集创建,各种网络层结构的定义,以及前向传播与权重更新方式。

比如这里给出

详解Pytorch 使用Pytorch拟合多项式(多项式回归)    

很显然,这里我们只需要假定

详解Pytorch 使用Pytorch拟合多项式(多项式回归)

这里我们只需要设置一个合适尺寸的全连接网络,根据不断迭代,求出最接近的参数即可。

但是这里需要思考一个问题,使用全连接网络结构是毫无疑问的,但是我们的输入与输出格式是什么样的呢?

只将一个x作为输入合理吗?显然是不合理的,因为每一个神经元其实模拟的是wx+b的计算过程,无法模拟幂运算,所以显然我们需要将x,x的平方,x的三次方,x的四次方组合成一个向量作为输入,假设有n个不同的x值,我们就可以将n个组合向量合在一起组成输入矩阵。

这一步代码如下:

def make_features(x): 
 x = x.unsqueeze(1) 
 return torch.cat([x ** i for i in range(1,4)] , 1)

我们需要生成一些随机数作为网络输入:

def get_batch(batch_size=32): 
 random = torch.randn(batch_size) 
 x = make_features(random) 
 '''Compute the actual results''' 
 y = f(x) 
 if torch.cuda.is_available(): 
  return Variable(x).cuda(), Variable(y).cuda() 
 else: 
  return Variable(x), Variable(y)

其中的f(x)定义如下:

w_target = torch.FloatTensor([0.5,3,2.4]).unsqueeze(1) 
b_target = torch.FloatTensor([0.9]) 
 
def f(x): 
 return x.mm(w_target)+b_target[0]

接下来定义模型:

class poly_model(nn.Module): 
 def __init__(self): 
  super(poly_model, self).__init__() 
  self.poly = nn.Linear(3,1) 
 
 def forward(self, x): 
  out = self.poly(x) 
  return out
if torch.cuda.is_available(): 
 model = poly_model().cuda() 
else: 
 model = poly_model()

接下来我们定义损失函数和优化器:

criterion = nn.MSELoss() 
optimizer = optim.SGD(model.parameters(), lr = 1e-3)

网络部件定义完后,开始训练:

epoch = 0 
while True: 
 batch_x,batch_y = get_batch() 
 output = model(batch_x) 
 loss = criterion(output,batch_y) 
 print_loss = loss.data[0] 
 optimizer.zero_grad() 
 loss.backward() 
 optimizer.step() 
 epoch+=1 
 if print_loss < 1e-3: 
  break

到此我们的所有代码就敲完了,接下来我们开始详细了解一下其中的一些代码。

在make_features()定义中,torch.cat是将计算出的向量拼接成矩阵。unsqueeze是作一个维度上的变化。

get_batch中,torch.randn是产生指定维度的随机数,如果你的机器支持GPU加速,可以将Variable放在GPU上进行运算,类似语句含义相通。

x.mm是作矩阵乘法。

模型定义是重中之重,其实当你掌握Pytorch之后,你会发现模型定义是十分简单的,各种基本的层结构都已经为你封装好了。所有的层结构和损失函数都来自torch.nn,所有的模型构建都是从这个基类 nn.Module继承的。模型定义中,__init__与forward是有模板的,大家可以自己体会。

nn.Linear是做一个线性的运算,参数的含义代表了输入层与输出层的结构,即3*1;在训练阶段,有几行是Pytorch不同于别的框架的,首先loss是一个Variable,通过loss.data可以取出一个Tensor,再通过data[0]可以得到一个int或者float类型的值,我们才可以进行基本运算或者显示。每次计算梯度之前,都需要将梯度归零,否则梯度会叠加。个人觉得别的语句还是比较好懂的,如果有疑问可以在下方评论。

下面是我们的拟合结果

详解Pytorch 使用Pytorch拟合多项式(多项式回归)

其实效果肯定会很好,因为只是一个非常简单的全连接网络,希望大家通过这个小例子可以学到Pytorch的一些基本操作。往后我们会继续更新,完整代码请戳,https://github.com/ZhichaoDuan/PytorchCourse

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
跟老齐学Python之再深点,更懂list
Sep 20 Python
Python缩进和冒号详解
Jun 01 Python
使用python生成杨辉三角形的示例代码
Aug 29 Python
Python实现个人微信号自动监控告警的示例
Jul 03 Python
Python编写带选项的命令行程序方法
Aug 13 Python
django多种支付、并发订单处理实例代码
Dec 13 Python
Pytorch 实现sobel算子的卷积操作详解
Jan 10 Python
python常用运维脚本实例小结
Feb 14 Python
Python 序列化和反序列化库 MarshMallow 的用法实例代码
Feb 25 Python
Python 面向对象部分知识点小结
Mar 09 Python
Python -m参数原理及使用方法解析
Aug 21 Python
Python爬虫基础之爬虫的分类知识总结
May 13 Python
Python获取系统所有进程PID及进程名称的方法示例
May 24 #Python
好的Python培训机构应该具备哪些条件
May 23 #Python
Python实现的根据IP地址计算子网掩码位数功能示例
May 23 #Python
Python加载带有注释的Json文件实例
May 23 #Python
Python实现判断一行代码是否为注释的方法
May 23 #Python
对python的文件内注释 help注释方法
May 23 #Python
Python基于生成器迭代实现的八皇后问题示例
May 23 #Python
You might like
删除无限分类并同时删除它下面的所有子分类的方法
2010/08/08 PHP
tagName的使用,留一笔
2006/06/26 Javascript
兼容多浏览器的字幕特效Marquee的通用js类
2008/07/20 Javascript
Jquery中增加参数与Json转换代码
2009/11/20 Javascript
Jquery 常用方法经典总结
2010/01/28 Javascript
求数组最大最小值方法适用于任何数组
2013/08/16 Javascript
JQuery文字列表向上滚动的代码
2013/11/13 Javascript
JQuery处理json与ajax返回JSON实例代码
2014/01/03 Javascript
javascript同步服务器时间和同步倒计时小技巧
2015/09/24 Javascript
JS实现仿腾讯微博无刷新删除微博效果代码
2015/10/16 Javascript
Bootstrap编写一个在当前网页弹出可关闭的对话框 非弹窗
2016/06/30 Javascript
快速解决js动态改变dom元素属性后页面及时渲染的问题
2016/07/06 Javascript
js print打印网页指定区域内容的简单实例
2016/11/01 Javascript
React中使用collections时key的重要性详解
2017/08/07 Javascript
基于JavaScript实现前端数据多条件筛选功能
2020/08/19 Javascript
Vue-Router模式和钩子的用法
2018/02/28 Javascript
深入理解JavaScript的async/await
2018/08/05 Javascript
p5.js绘制创意自画像
2019/11/04 Javascript
微信小程序实现多选框全选与反全选及购物车中删除选中的商品功能
2019/12/17 Javascript
原生js实现文件上传、下载、封装等实例方法
2020/01/05 Javascript
原生js实现的观察者和订阅者模式简单示例
2020/04/18 Javascript
微信小程序开发(二):页面跳转并传参操作示例
2020/06/01 Javascript
在NodeJs中使用node-schedule增加定时器任务的方法
2020/06/08 NodeJs
回调函数的意义以及python实现实例
2017/06/20 Python
Tensorflow--取tensorf指定列的操作方式
2020/06/30 Python
css3 响应式媒体查询的示例代码
2019/09/25 HTML / CSS
HTML5实现锚点时请使用id取代name
2013/09/06 HTML / CSS
美国特价机票专家:Airfarewatchdog
2018/01/24 全球购物
Strathberry苏贝瑞中国官网:西班牙高级工匠手工打造
2020/10/19 全球购物
精伦电子Java笔试题
2013/01/16 面试题
村级个人对照检查材料
2014/08/22 职场文书
副总经理岗位职责
2015/02/02 职场文书
2015年高校教师个人工作总结
2015/05/25 职场文书
Python可变与不可变数据和深拷贝与浅拷贝
2022/04/06 Python
vue @ ~ 相对路径 路径别名设置方式
2022/06/05 Vue.js
MySQL慢查询中的commit慢和binlog中慢事务的区别
2022/06/16 MySQL