神经网络(BP)算法Python实现及应用


Posted in Python onApril 16, 2018

本文实例为大家分享了Python实现神经网络算法及应用的具体代码,供大家参考,具体内容如下

首先用Python实现简单地神经网络算法:

import numpy as np


# 定义tanh函数
def tanh(x):
  return np.tanh(x)


# tanh函数的导数
def tan_deriv(x):
  return 1.0 - np.tanh(x) * np.tan(x)


# sigmoid函数
def logistic(x):
  return 1 / (1 + np.exp(-x))


# sigmoid函数的导数
def logistic_derivative(x):
  return logistic(x) * (1 - logistic(x))


class NeuralNetwork:
  def __init__(self, layers, activation='tanh'):
    """
    神经网络算法构造函数
    :param layers: 神经元层数
    :param activation: 使用的函数(默认tanh函数)
    :return:none
    """
    if activation == 'logistic':
      self.activation = logistic
      self.activation_deriv = logistic_derivative
    elif activation == 'tanh':
      self.activation = tanh
      self.activation_deriv = tan_deriv

    # 权重列表
    self.weights = []
    # 初始化权重(随机)
    for i in range(1, len(layers) - 1):
      self.weights.append((2 * np.random.random((layers[i - 1] + 1, layers[i] + 1)) - 1) * 0.25)
      self.weights.append((2 * np.random.random((layers[i] + 1, layers[i + 1])) - 1) * 0.25)

  def fit(self, X, y, learning_rate=0.2, epochs=10000):
    """
    训练神经网络
    :param X: 数据集(通常是二维)
    :param y: 分类标记
    :param learning_rate: 学习率(默认0.2)
    :param epochs: 训练次数(最大循环次数,默认10000)
    :return: none
    """
    # 确保数据集是二维的
    X = np.atleast_2d(X)

    temp = np.ones([X.shape[0], X.shape[1] + 1])
    temp[:, 0: -1] = X
    X = temp
    y = np.array(y)

    for k in range(epochs):
      # 随机抽取X的一行
      i = np.random.randint(X.shape[0])
      # 用随机抽取的这一组数据对神经网络更新
      a = [X[i]]
      # 正向更新
      for l in range(len(self.weights)):
        a.append(self.activation(np.dot(a[l], self.weights[l])))
      error = y[i] - a[-1]
      deltas = [error * self.activation_deriv(a[-1])]

      # 反向更新
      for l in range(len(a) - 2, 0, -1):
        deltas.append(deltas[-1].dot(self.weights[l].T) * self.activation_deriv(a[l]))
        deltas.reverse()
      for i in range(len(self.weights)):
        layer = np.atleast_2d(a[i])
        delta = np.atleast_2d(deltas[i])
        self.weights[i] += learning_rate * layer.T.dot(delta)

  def predict(self, x):
    x = np.array(x)
    temp = np.ones(x.shape[0] + 1)
    temp[0:-1] = x
    a = temp
    for l in range(0, len(self.weights)):
      a = self.activation(np.dot(a, self.weights[l]))
    return a

使用自己定义的神经网络算法实现一些简单的功能:

 小案例:

X:                  Y
0 0                 0
0 1                 1
1 0                 1
1 1                 0

from NN.NeuralNetwork import NeuralNetwork
import numpy as np

nn = NeuralNetwork([2, 2, 1], 'tanh')
temp = [[0, 0], [0, 1], [1, 0], [1, 1]]
X = np.array(temp)
y = np.array([0, 1, 1, 0])
nn.fit(X, y)
for i in temp:
  print(i, nn.predict(i))

神经网络(BP)算法Python实现及应用

发现结果基本机制,无限接近0或者无限接近1 

第二个例子:识别图片中的数字

导入数据:

from sklearn.datasets import load_digits
import pylab as pl

digits = load_digits()
print(digits.data.shape)
pl.gray()
pl.matshow(digits.images[0])
pl.show()

观察下:大小:(1797, 64)

数字0

神经网络(BP)算法Python实现及应用

接下来的代码是识别它们:

import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import LabelBinarizer
from NN.NeuralNetwork import NeuralNetwork
from sklearn.cross_validation import train_test_split

# 加载数据集
digits = load_digits()
X = digits.data
y = digits.target
# 处理数据,使得数据处于0,1之间,满足神经网络算法的要求
X -= X.min()
X /= X.max()

# 层数:
# 输出层10个数字
# 输入层64因为图片是8*8的,64像素
# 隐藏层假设100
nn = NeuralNetwork([64, 100, 10], 'logistic')
# 分隔训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y)

# 转化成sklearn需要的二维数据类型
labels_train = LabelBinarizer().fit_transform(y_train)
labels_test = LabelBinarizer().fit_transform(y_test)
print("start fitting")
# 训练3000次
nn.fit(X_train, labels_train, epochs=3000)
predictions = []
for i in range(X_test.shape[0]):
  o = nn.predict(X_test[i])
  # np.argmax:第几个数对应最大概率值
  predictions.append(np.argmax(o))

# 打印预测相关信息
print(confusion_matrix(y_test, predictions))
print(classification_report(y_test, predictions))

结果:

矩阵对角线代表预测正确的数量,发现正确率很多

神经网络(BP)算法Python实现及应用

这张表更直观地显示出预测正确率:

共450个案例,成功率94%

神经网络(BP)算法Python实现及应用

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python递归计算N!的方法
May 05 Python
Python的pycurl包用法简介
Nov 13 Python
python编码总结(编码类型、格式、转码)
Jul 01 Python
离线安装Pyecharts的步骤以及依赖包流程
Apr 23 Python
Python 中迭代器与生成器实例详解
Mar 29 Python
DataFrame 将某列数据转为数组的方法
Apr 13 Python
Python3之简单搭建自带服务器的实例讲解
Jun 04 Python
python实现简单日期工具类
Apr 24 Python
Python实现Wordcloud生成词云图的示例
Mar 30 Python
Python 线性回归分析以及评价指标详解
Apr 02 Python
解决python ThreadPoolExecutor 线程池中的异常捕获问题
Apr 08 Python
Django搭建项目实战与避坑细节详解
Dec 06 Python
python读取视频流提取视频帧的两种方法
Oct 22 #Python
python读取和保存视频文件
Apr 16 #Python
Python读取视频的两种方法(imageio和cv2)
Apr 15 #Python
python2.7实现FTP文件下载功能
Apr 15 #Python
python实现多线程网页下载器
Apr 15 #Python
Python实现定时精度可调节的定时器
Apr 15 #Python
Python编写一个优美的下载器
Apr 15 #Python
You might like
基于wordpress主题制作的具体实现步骤
2013/05/10 PHP
codeigniter框架批量插入数据
2014/01/09 PHP
php限制上传文件类型并保存上传文件的方法
2015/03/13 PHP
PHP中iconv函数知识汇总
2015/07/02 PHP
PHP 数组基本操作方法详解
2016/06/17 PHP
Mootools 1.2教程 函数
2009/09/15 Javascript
在IE和VB中支持png图片透明效果的实现方法(vb源码打包)
2011/04/01 Javascript
javascript在子页面中函数无法调试问题解决方法
2014/01/17 Javascript
jquery五角星评分插件示例分享
2014/02/21 Javascript
JavaScript字符串对象charAt方法入门实例(用于取得指定位置的字符)
2014/10/17 Javascript
浅谈 javascript 事件处理
2015/01/04 Javascript
JavaScript控制两个列表框listbox左右交换数据的方法
2015/03/18 Javascript
WEB前端开发都应知道的jquery小技巧及jquery三个简写
2015/11/15 Javascript
jQuery的中 is(':visible') 解析及用法(必看)
2017/02/12 Javascript
JavaScript 完成注册页面表单校验的实例
2017/08/19 Javascript
原生JS实现的碰撞检测功能示例
2018/05/18 Javascript
详解微信小程序框架wepy踩坑记录(与vue对比)
2019/03/12 Javascript
jquery.pager.js实现分页效果
2019/07/29 jQuery
vue在路由中验证token是否存在的简单实现
2019/11/11 Javascript
js实现简单的随机点名器
2020/09/17 Javascript
关于Vue中$refs的探索浅析
2020/11/05 Javascript
[53:43]VP vs NewBee Supermajor 胜者组 BO3 第三场 6.5
2018/06/06 DOTA
以Flask为例讲解Python的框架的使用方法
2015/04/29 Python
Python实现Sqlite将字段当做索引进行查询的方法
2016/07/21 Python
django开发之settings.py中变量的全局引用详解
2017/03/29 Python
Python排序搜索基本算法之堆排序实例详解
2017/12/08 Python
python3+selenium获取页面加载的所有静态资源文件链接操作
2020/05/04 Python
Python接口自动化测试框架运行原理及流程
2020/11/30 Python
打架检讨书500字
2014/01/29 职场文书
《春到梅花山》教学反思
2014/04/16 职场文书
服务行业演讲稿
2014/09/02 职场文书
查摆问题自查报告范文
2014/10/13 职场文书
2015年员工工作总结范文
2015/04/08 职场文书
2016春季运动会通讯稿
2015/07/18 职场文书
Java+swing实现抖音上的表白程序详解
2022/06/25 Java/Android
Python find()、rfind()方法及作用
2022/12/24 Python