如何用Python 实现全连接神经网络(Multi-layer Perceptron)


Posted in Python onOctober 15, 2020

代码

import numpy as np

# 各种激活函数及导数
def sigmoid(x):
  return 1 / (1 + np.exp(-x))


def dsigmoid(y):
  return y * (1 - y)


def tanh(x):
  return np.tanh(x)


def dtanh(y):
  return 1.0 - y ** 2


def relu(y):
  tmp = y.copy()
  tmp[tmp < 0] = 0
  return tmp


def drelu(x):
  tmp = x.copy()
  tmp[tmp >= 0] = 1
  tmp[tmp < 0] = 0
  return tmp


class MLPClassifier(object):
  """多层感知机,BP 算法训练"""

  def __init__(self,
         layers,
         activation='tanh',
         epochs=20, batch_size=1, learning_rate=0.01):
    """
    :param layers: 网络层结构
    :param activation: 激活函数
    :param epochs: 迭代轮次
    :param learning_rate: 学习率 
    """
    self.epochs = epochs
    self.learning_rate = learning_rate
    self.layers = []
    self.weights = []
    self.batch_size = batch_size

    for i in range(0, len(layers) - 1):
      weight = np.random.random((layers[i], layers[i + 1]))
      layer = np.ones(layers[i])
      self.layers.append(layer)
      self.weights.append(weight)
    self.layers.append(np.ones(layers[-1]))

    self.thresholds = []
    for i in range(1, len(layers)):
      threshold = np.random.random(layers[i])
      self.thresholds.append(threshold)

    if activation == 'tanh':
      self.activation = tanh
      self.dactivation = dtanh
    elif activation == 'sigomid':
      self.activation = sigmoid
      self.dactivation = dsigmoid
    elif activation == 'relu':
      self.activation = relu
      self.dactivation = drelu

  def fit(self, X, y):
    """
    :param X_: shape = [n_samples, n_features] 
    :param y: shape = [n_samples] 
    :return: self
    """
    for _ in range(self.epochs * (X.shape[0] // self.batch_size)):
      i = np.random.choice(X.shape[0], self.batch_size)
      # i = np.random.randint(X.shape[0])
      self.update(X[i])
      self.back_propagate(y[i])

  def predict(self, X):
    """
    :param X: shape = [n_samples, n_features] 
    :return: shape = [n_samples]
    """
    self.update(X)
    return self.layers[-1].copy()

  def update(self, inputs):
    self.layers[0] = inputs
    for i in range(len(self.weights)):
      next_layer_in = self.layers[i] @ self.weights[i] - self.thresholds[i]
      self.layers[i + 1] = self.activation(next_layer_in)

  def back_propagate(self, y):
    errors = y - self.layers[-1]

    gradients = [(self.dactivation(self.layers[-1]) * errors).sum(axis=0)]

    self.thresholds[-1] -= self.learning_rate * gradients[-1]
    for i in range(len(self.weights) - 1, 0, -1):
      tmp = np.sum(gradients[-1] @ self.weights[i].T * self.dactivation(self.layers[i]), axis=0)
      gradients.append(tmp)
      self.thresholds[i - 1] -= self.learning_rate * gradients[-1] / self.batch_size
    gradients.reverse()
    for i in range(len(self.weights)):
      tmp = np.mean(self.layers[i], axis=0)
      self.weights[i] += self.learning_rate * tmp.reshape((-1, 1)) * gradients[i]

测试代码

import sklearn.datasets
import numpy as np

def plot_decision_boundary(pred_func, X, y, title=None):
  """分类器画图函数,可画出样本点和决策边界
  :param pred_func: predict函数
  :param X: 训练集X
  :param y: 训练集Y
  :return: None
  """

  # Set min and max values and give it some padding
  x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
  y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
  h = 0.01
  # Generate a grid of points with distance h between them
  xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
  # Predict the function value for the whole gid
  Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
  Z = Z.reshape(xx.shape)
  # Plot the contour and training examples
  plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
  plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)

  if title:
    plt.title(title)
  plt.show()


def test_mlp():
  X, y = sklearn.datasets.make_moons(200, noise=0.20)
  y = y.reshape((-1, 1))
  n = MLPClassifier((2, 3, 1), activation='tanh', epochs=300, learning_rate=0.01)
  n.fit(X, y)
  def tmp(X):
    sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)
    ans = sign(n.predict(X))
    return ans

  plot_decision_boundary(tmp, X, y, 'Neural Network')

效果

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

如何用Python 实现全连接神经网络(Multi-layer Perceptron)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是如何用Python 实现全连接神经网络(Multi-layer Perceptron)的详细内容,更多关于Python 实现全连接神经网络的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python对列表排序的方法实例分析
May 16 Python
python中zip和unzip数据的方法
May 27 Python
python中偏函数partial用法实例分析
Jul 08 Python
详解Django框架中用户的登录和退出的实现
Jul 23 Python
python+matplotlib绘制3D条形图实例代码
Jan 17 Python
Python查看微信撤回消息代码
Jun 07 Python
Django中的Model操作表的实现
Jul 24 Python
python实现反转部分单向链表
Sep 27 Python
python实现简易动态时钟
Nov 19 Python
Python lambda表达式filter、map、reduce函数用法解析
Sep 11 Python
python实现高斯投影正反算方式
Jan 17 Python
Python jieba结巴分词原理及用法解析
Nov 05 Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
python实现粒子群算法
Oct 15 #Python
如何将anaconda安装配置的mmdetection环境离线拷贝到另一台电脑
Oct 15 #Python
Python3.7安装PyQt5 运行配置Pycharm的详细教程
Oct 15 #Python
python利用faker库批量生成测试数据
Oct 15 #Python
如何利用python检测图片是否包含二维码
Oct 15 #Python
You might like
PHP 页面跳转到另一个页面的多种方法方法总结
2009/07/07 PHP
php 自写函数代码 获取关键字 去超链接
2010/02/08 PHP
php连接mssql数据库的几种方法
2013/02/21 PHP
解析zend Framework如何自动加载类
2013/06/28 PHP
php邮件发送的两种方式
2020/04/28 PHP
Yii2使用$this-&gt;context获取当前的Module、Controller(控制器)、Action等
2017/03/29 PHP
js控制表单不能输入空格的小例子
2013/11/20 Javascript
jquery 无限级下拉菜单的简单实现代码
2014/02/21 Javascript
js中window.open打开一个新的页面
2014/08/10 Javascript
浅谈Sublime Text 3运行JavaScript控制台
2016/06/06 Javascript
深入理解JavaScript函数参数(推荐)
2016/07/26 Javascript
Backbone中View之间传值的学习心得
2016/08/09 Javascript
AngularJS入门教程之REST和定制服务详解
2016/08/19 Javascript
Sequelize中用group by进行分组聚合查询
2016/12/12 Javascript
解决Window10系统下Node安装报错的问题分析
2016/12/13 Javascript
基于vue的下拉刷新指令和滚动刷新指令
2016/12/23 Javascript
js+div+css下拉导航菜单完整代码分享
2016/12/28 Javascript
详解使用Vue.Js结合Jquery Ajax加载数据的两种方式
2017/01/10 Javascript
JavaScript ES6中export、import与export default的用法和区别
2017/03/14 Javascript
详解vue-router 2.0 常用基础知识点之router-link
2017/05/10 Javascript
微信页面弹出键盘后iframe内容变空白的解决方案
2017/09/20 Javascript
Vue项目使用CDN优化首屏加载问题
2018/04/01 Javascript
详解js常用分割取字符串的方法
2019/05/15 Javascript
webpack + vue 打包生成公共配置文件(域名) 方便动态修改
2019/08/29 Javascript
python3中的md5加密实例
2018/05/29 Python
详解Python发送email的三种方式
2018/10/18 Python
Python二叉搜索树与双向链表转换算法示例
2019/03/02 Python
在Python中使用filter去除列表中值为假及空字符串的例子
2019/11/18 Python
Linux Interview Questions For software testers
2012/06/02 面试题
Set里的元素是不能重复的,那么用什么方法来区分重复与否呢?
2016/08/18 面试题
医学生自荐信
2013/12/03 职场文书
仓管岗位职责范本
2014/02/08 职场文书
机关干部四风问题自查报告及整改措施
2014/10/26 职场文书
2016年清明节期间群众祭祀活动工作总结
2016/04/01 职场文书
导游词之新疆尼雅遗址
2019/10/16 职场文书
竞选稿之小学班干部
2019/10/31 职场文书