Python实现的NN神经网络算法完整示例


Posted in Python onJune 19, 2018

本文实例讲述了Python实现的NN神经网络算法。分享给大家供大家参考,具体如下:

参考自Github开源代码:https://github.com/dennybritz/nn-from-scratch

运行环境

  • Pyhton3
  • numpy(科学计算包)
  • matplotlib(画图所需,不画图可不必)
  • sklearn(人工智能包,生成数据使用)

计算过程

Python实现的NN神经网络算法完整示例

输入样例

none

代码实现

# -*- coding:utf-8 -*-
#!python3
__author__ = 'Wsine'
import numpy as np
import sklearn
import sklearn.datasets
import sklearn.linear_model
import matplotlib.pyplot as plt
import matplotlib
import operator
import time
def createData(dim=200, cnoise=0.20):
  """
  输出:数据集, 对应的类别标签
  描述:生成一个数据集和对应的类别标签
  """
  np.random.seed(0)
  X, y = sklearn.datasets.make_moons(dim, noise=cnoise)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)
  #plt.show()
  return X, y
def plot_decision_boundary(pred_func, X, y):
  """
  输入:边界函数, 数据集, 类别标签
  描述:绘制决策边界(画图用)
  """
  # 设置最小最大值, 加上一点外边界
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # 根据最小最大值和一个网格距离生成整个网格
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # 对整个网格预测边界值
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # 绘制边界和数据集的点
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Spectral)
def calculate_loss(model, X, y):
  """
  输入:训练模型, 数据集, 类别标签
  输出:误判的概率
  描述:计算整个模型的性能
  """
  W1, b1, W2, b2 = model['W1'], model['b1'], model['W2'], model['b2']
  # 正向传播来计算预测的分类值
  z1 = X.dot(W1) + b1
  a1 = np.tanh(z1)
  z2 = a1.dot(W2) + b2
  exp_scores = np.exp(z2)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
  # 计算误判概率
  corect_logprobs = -np.log(probs[range(num_examples), y])
  data_loss = np.sum(corect_logprobs)
  # 加入正则项修正错误(可选)
  data_loss += reg_lambda/2 * (np.sum(np.square(W1)) + np.sum(np.square(W2)))
  return 1./num_examples * data_loss
def predict(model, x):
  """
  输入:训练模型, 预测向量
  输出:判决类别
  描述:预测类别属于(0 or 1)
  """
  W1, b1, W2, b2 = model['W1'], model['b1'], model['W2'], model['b2']
  # 正向传播计算
  z1 = x.dot(W1) + b1
  a1 = np.tanh(z1)
  z2 = a1.dot(W2) + b2
  exp_scores = np.exp(z2)
  probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
  return np.argmax(probs, axis=1)
def initParameter(X):
  """
  输入:数据集
  描述:初始化神经网络算法的参数
     必须初始化为全局函数!
     这里需要手动设置!
  """
  global num_examples
  num_examples = len(X) # 训练集的大小
  global nn_input_dim
  nn_input_dim = 2 # 输入层维数
  global nn_output_dim
  nn_output_dim = 2 # 输出层维数
  # 梯度下降参数
  global epsilon
  epsilon = 0.01 # 梯度下降学习步长
  global reg_lambda
  reg_lambda = 0.01 # 修正的指数
def build_model(X, y, nn_hdim, num_passes=20000, print_loss=False):
  """
  输入:数据集, 类别标签, 隐藏层层数, 迭代次数, 是否输出误判率
  输出:神经网络模型
  描述:生成一个指定层数的神经网络模型
  """
  # 根据维度随机初始化参数
  np.random.seed(0)
  W1 = np.random.randn(nn_input_dim, nn_hdim) / np.sqrt(nn_input_dim)
  b1 = np.zeros((1, nn_hdim))
  W2 = np.random.randn(nn_hdim, nn_output_dim) / np.sqrt(nn_hdim)
  b2 = np.zeros((1, nn_output_dim))
  model = {}
  # 梯度下降
  for i in range(0, num_passes):
    # 正向传播
    z1 = X.dot(W1) + b1
    a1 = np.tanh(z1) # 激活函数使用tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    z2 = a1.dot(W2) + b2
    exp_scores = np.exp(z2) # 原始归一化
    probs = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)
    # 后向传播
    delta3 = probs
    delta3[range(num_examples), y] -= 1
    dW2 = (a1.T).dot(delta3)
    db2 = np.sum(delta3, axis=0, keepdims=True)
    delta2 = delta3.dot(W2.T) * (1 - np.power(a1, 2))
    dW1 = np.dot(X.T, delta2)
    db1 = np.sum(delta2, axis=0)
    # 加入修正项
    dW2 += reg_lambda * W2
    dW1 += reg_lambda * W1
    # 更新梯度下降参数
    W1 += -epsilon * dW1
    b1 += -epsilon * db1
    W2 += -epsilon * dW2
    b2 += -epsilon * db2
    # 更新模型
    model = { 'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}
    # 一定迭代次数后输出当前误判率
    if print_loss and i % 1000 == 0:
      print("Loss after iteration %i: %f" % (i, calculate_loss(model, X, y)))
  plot_decision_boundary(lambda x: predict(model, x), X, y)
  plt.title("Decision Boundary for hidden layer size %d" % nn_hdim)
  #plt.show()
  return model
def main():
  dataSet, labels = createData(200, 0.20)
  initParameter(dataSet)
  nnModel = build_model(dataSet, labels, 3, print_loss=False)
  print("Loss is %f" % calculate_loss(nnModel, dataSet, labels))
if __name__ == '__main__':
  start = time.clock()
  main()
  end = time.clock()
  print('finish all in %s' % str(end - start))
  plt.show()

输出样例

Loss is 0.071316
finish all in 7.221354361552228

Python实现的NN神经网络算法完整示例

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
python读取Android permission文件
Nov 01 Python
python中文件变化监控示例(watchdog)
Oct 16 Python
[原创]Python入门教程3. 列表基本操作【定义、运算、常用函数】
Oct 30 Python
基于django传递数据到后端的例子
Aug 16 Python
解决python有时候import不了当前的包问题
Aug 28 Python
详解Python中字符串前“b”,“r”,“u”,“f”的作用
Dec 18 Python
Pytorch mask_select 函数的用法详解
Feb 18 Python
使用 pytorch 创建神经网络拟合sin函数的实现
Feb 24 Python
使用python编写一个语音朗读闹钟功能的示例代码
Jul 14 Python
Python字典fromkeys()方法使用代码实例
Jul 20 Python
详解python变量与数据类型
Aug 25 Python
Matplotlib animation模块实现动态图
Feb 25 Python
python中的二维列表实例详解
Jun 19 #Python
Tensorflow中使用tfrecord方式读取数据的方法
Jun 19 #Python
python3实现SMTP发送邮件详细教程
Jun 19 #Python
Python SVM(支持向量机)实现方法完整示例
Jun 19 #Python
Tensorflow使用tfrecord输入数据格式
Jun 19 #Python
Tensorflow 训练自己的数据集将数据直接导入到内存
Jun 19 #Python
python如何爬取个性签名
Jun 19 #Python
You might like
正义联盟的终局之战《天启星战争》将成为DC动画宇宙的最后一部
2020/04/09 欧美动漫
介绍php设计模式中的工厂模式
2008/06/12 PHP
php下通过伪造http头破解防盗链的代码
2010/07/03 PHP
Notice: Undefined index: page in E:\PHP\test.php on line 14
2010/11/02 PHP
简单介绍win7下搭建apache+php+mysql开发环境
2015/08/06 PHP
完美解决phpdoc导出文档中@package的warning及Error的错误
2016/05/17 PHP
List all the Databases on a SQL Server
2007/06/21 Javascript
jValidate 基于jQuery的表单验证插件
2009/12/12 Javascript
提交表单时执行func方法实现代码
2013/03/17 Javascript
不要使用jQuery触发原生事件的方法
2014/03/03 Javascript
js实现数组冒泡排序、快速排序原理
2016/03/08 Javascript
移动端日期插件Mobiscroll.js使用详解
2016/12/19 Javascript
ES6字符串模板,剩余参数,默认参数功能与用法示例
2017/04/06 Javascript
nodejs入门教程二:创建一个简单应用示例
2017/04/24 NodeJs
Angular.js ng-file-upload结合springMVC的使用教程
2017/07/10 Javascript
微信小程序scroll-x失效的完美解决方法
2018/07/18 Javascript
Vue面试题及Vue知识点整理
2018/10/07 Javascript
JavaScript实现邮箱后缀提示功能的示例代码
2018/12/13 Javascript
小程序云开发获取不到数据库记录的解决方法
2019/05/18 Javascript
解决layui checkbox 提交多个值的问题
2019/09/02 Javascript
微信小程序封装多张图片上传api代码实例
2019/12/30 Javascript
swiper4实现移动端导航栏tab滑动切换
2020/10/16 Javascript
使用TS来编写express服务器的方法步骤
2020/10/29 Javascript
python 禁止函数修改列表的实现方法
2017/08/03 Python
Python实现按学生年龄排序的实际问题详解
2017/08/29 Python
在pycharm中使用git版本管理以及同步github的方法
2019/01/16 Python
Django中信号signals的简单使用方法
2019/07/04 Python
Python八皇后问题解答过程详解
2019/07/29 Python
关于python字符串方法分类详解
2019/08/20 Python
python如何判断IP地址合法性
2020/04/05 Python
Python实现在线批量美颜功能过程解析
2020/06/10 Python
Python调用OpenCV实现图像平滑代码实例
2020/06/19 Python
党支部综合考察材料
2014/05/19 职场文书
化验室安全管理制度
2015/08/06 职场文书
2016简单的租房合同范本
2016/03/18 职场文书
Python还能这么玩之只用30行代码从excel提取个人值班表
2021/06/05 Python