Python使用numpy实现BP神经网络


Posted in Python onMarch 10, 2018

本文完全利用numpy实现一个简单的BP神经网络,由于是做regression而不是classification,因此在这里输出层选取的激励函数就是f(x)=x。BP神经网络的具体原理此处不再介绍。

import numpy as np 
 
class NeuralNetwork(object): 
  def __init__(self, input_nodes, hidden_nodes, output_nodes, learning_rate): 
    # Set number of nodes in input, hidden and output layers.设定输入层、隐藏层和输出层的node数目 
    self.input_nodes = input_nodes 
    self.hidden_nodes = hidden_nodes 
    self.output_nodes = output_nodes 
 
    # Initialize weights,初始化权重和学习速率 
    self.weights_input_to_hidden = np.random.normal(0.0, self.hidden_nodes**-0.5,  
                    ( self.hidden_nodes, self.input_nodes)) 
 
    self.weights_hidden_to_output = np.random.normal(0.0, self.output_nodes**-0.5,  
                    (self.output_nodes, self.hidden_nodes)) 
    self.lr = learning_rate 
     
    # 隐藏层的激励函数为sigmoid函数,Activation function is the sigmoid function 
    self.activation_function = (lambda x: 1/(1 + np.exp(-x))) 
   
  def train(self, inputs_list, targets_list): 
    # Convert inputs list to 2d array 
    inputs = np.array(inputs_list, ndmin=2).T  # 输入向量的shape为 [feature_diemension, 1] 
    targets = np.array(targets_list, ndmin=2).T  
 
    # 向前传播,Forward pass 
    # TODO: Hidden layer 
    hidden_inputs = np.dot(self.weights_input_to_hidden, inputs) # signals into hidden layer 
    hidden_outputs = self.activation_function(hidden_inputs) # signals from hidden layer 
 
     
    # 输出层,输出层的激励函数就是 y = x 
    final_inputs = np.dot(self.weights_hidden_to_output, hidden_outputs) # signals into final output layer 
    final_outputs = final_inputs # signals from final output layer 
     
    ### 反向传播 Backward pass,使用梯度下降对权重进行更新 ### 
     
    # 输出误差 
    # Output layer error is the difference between desired target and actual output. 
    output_errors = (targets_list-final_outputs) 
 
    # 反向传播误差 Backpropagated error 
    # errors propagated to the hidden layer 
    hidden_errors = np.dot(output_errors, self.weights_hidden_to_output)*(hidden_outputs*(1-hidden_outputs)).T 
 
    # 更新权重 Update the weights 
    # 更新隐藏层与输出层之间的权重 update hidden-to-output weights with gradient descent step 
    self.weights_hidden_to_output += output_errors * hidden_outputs.T * self.lr 
    # 更新输入层与隐藏层之间的权重 update input-to-hidden weights with gradient descent step 
    self.weights_input_to_hidden += (inputs * hidden_errors * self.lr).T 
  
  # 进行预测   
  def run(self, inputs_list): 
    # Run a forward pass through the network 
    inputs = np.array(inputs_list, ndmin=2).T 
     
    #### 实现向前传播 Implement the forward pass here #### 
    # 隐藏层 Hidden layer 
    hidden_inputs = np.dot(self.weights_input_to_hidden, inputs) # signals into hidden layer 
    hidden_outputs = self.activation_function(hidden_inputs) # signals from hidden layer 
     
    # 输出层 Output layer 
    final_inputs = np.dot(self.weights_hidden_to_output, hidden_outputs) # signals into final output layer 
    final_outputs = final_inputs # signals from final output layer  
     
    return final_outputs

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python进阶教程之模块(module)介绍
Aug 30 Python
基于python元祖与字典与集合的粗浅认识
Aug 23 Python
python测试mysql写入性能完整实例
Jan 18 Python
python实现决策树分类
Aug 30 Python
通过cmd进入python的实例操作
Jun 26 Python
python3实现猜数字游戏
Dec 07 Python
django框架使用方法详解
Jul 18 Python
wxpython绘制圆角窗体
Nov 18 Python
Python接口测试文件上传实例解析
May 22 Python
彻底解决pip下载pytorch慢的问题方法
Mar 01 Python
Matlab求解数组中的最大值及它所在的具体位置
Apr 16 Python
Django实现WebSocket在线聊天室功能(channels库)
Sep 25 Python
python实现日常记账本小程序
Mar 10 #Python
python实现简单神经网络算法
Mar 10 #Python
TensorFlow saver指定变量的存取
Mar 10 #Python
TensorFLow用Saver保存和恢复变量
Mar 10 #Python
tensorflow创建变量以及根据名称查找变量
Mar 10 #Python
Python2中文处理纪要的实现方法
Mar 10 #Python
python实现冒泡排序算法的两种方法
Mar 10 #Python
You might like
一个用php实现的获取URL信息的类
2007/01/02 PHP
PHP has encountered an Access Violation at 7C94BD02解决方法
2009/08/24 PHP
PHP+SQL 注入攻击的技术实现以及预防办法
2011/01/27 PHP
php数组指针操作详解
2017/02/14 PHP
js类中获取外部函数名的方法
2007/08/19 Javascript
关于 byval 与 byref 的区别分析总结
2007/10/08 Javascript
通过event对象的fromElement属性解决热区设置主实体的一个bug
2008/12/22 Javascript
JS input文本框禁用右键和复制粘贴功能的代码
2010/04/15 Javascript
javascript的console.log()用法小结
2012/05/31 Javascript
js中的push和join方法使用介绍
2013/10/08 Javascript
7个有用的jQuery代码片段分享
2015/05/19 Javascript
javascript实现获取服务器时间
2015/05/19 Javascript
jquery实现清新实用的网页菜单效果
2015/08/28 Javascript
javascript性能优化之DOM交互操作实例分析
2015/12/12 Javascript
jQuery简单实现彩色云标签效果示例
2016/08/01 Javascript
node.js实现博客小爬虫的实例代码
2016/10/08 Javascript
JS复制对应id的内容到粘贴板(Ctrl+C效果)
2017/01/23 Javascript
Vue入门之数据绑定(小结)
2018/01/08 Javascript
简单介绍react redux的中间件的使用
2018/04/06 Javascript
小程序实现上下移动切换位置
2019/09/23 Javascript
基于html+css+js实现简易计算器代码实例
2020/02/28 Javascript
javascript实现前端成语点击验证
2020/06/24 Javascript
vue npm install 安装某个指定的版本操作
2020/08/11 Javascript
使用Python来开发Markdown脚本扩展的实例分享
2016/03/04 Python
Python3.X 线程中信号量的使用方法示例
2017/07/24 Python
python画折线图的程序
2018/07/26 Python
python K近邻算法的kd树实现
2018/09/06 Python
python实现上传文件到linux指定目录的方法
2020/01/03 Python
CSS3 RGBA色彩模式使用实例讲解
2016/04/26 HTML / CSS
5分钟实现Canvas鼠标跟随动画背景
2019/11/18 HTML / CSS
菲律宾购物网站:Lazada菲律宾
2018/04/05 全球购物
美国在线健康和美容市场:Pharmapacks
2018/12/05 全球购物
公司财务工作总结的自我评价
2013/11/23 职场文书
美术教师求职信范文
2015/03/20 职场文书
企业转让协议书(范文2篇)
2019/08/15 职场文书
python基础之模块的导入
2021/10/24 Python