numpy实现RNN原理实现


Posted in Python onMarch 02, 2021

首先说明代码只是帮助理解,并未写出梯度下降部分,默认参数已经被固定,不影响理解。代码主要实现RNN原理,只使用numpy库,不可用于GPU加速。

import numpy as np


class Rnn():

  def __init__(self, input_size, hidden_size, num_layers, bidirectional=False):
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bidirectional = bidirectional

  def feed(self, x):
    '''

    :param x: [seq, batch_size, embedding]
    :return: out, hidden
    '''

    # x.shape [sep, batch, feature]
    # hidden.shape [hidden_size, batch]
    # Whh0.shape [hidden_size, hidden_size] Wih0.shape [hidden_size, feature]
    # Whh1.shape [hidden_size, hidden_size] Wih1.size [hidden_size, hidden_size]

    out = []
    x, hidden = np.array(x), [np.zeros((self.hidden_size, x.shape[1])) for i in range(self.num_layers)]
    Wih = [np.random.random((self.hidden_size, self.hidden_size)) for i in range(1, self.num_layers)]
    Wih.insert(0, np.random.random((self.hidden_size, x.shape[2])))
    Whh = [np.random.random((self.hidden_size, self.hidden_size)) for i in range(self.num_layers)]

    time = x.shape[0]
    for i in range(time):
      hidden[0] = np.tanh((np.dot(Wih[0], np.transpose(x[i, ...], (1, 0))) +
               np.dot(Whh[0], hidden[0])
               ))

      for i in range(1, self.num_layers):
        hidden[i] = np.tanh((np.dot(Wih[i], hidden[i-1]) +
                   np.dot(Whh[i], hidden[i])
                   ))

      out.append(hidden[self.num_layers-1])

    return np.array(out), np.array(hidden)


def sigmoid(x):
  return 1.0/(1.0 + 1.0/np.exp(x))


if __name__ == '__main__':
  rnn = Rnn(1, 5, 4)
  input = np.random.random((6, 2, 1))
  out, h = rnn.feed(input)
  print(f'seq is {input.shape[0]}, batch_size is {input.shape[1]} ', 'out.shape ', out.shape, ' h.shape ', h.shape)
  # print(sigmoid(np.random.random((2, 3))))
  #
  # element-wise multiplication
  # print(np.array([1, 2])*np.array([2, 1]))

到此这篇关于numpy实现RNN原理实现的文章就介绍到这了,更多相关numpy实现RNN内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python多线程实例教程
Sep 06 Python
python实现的AES双向对称加密解密与用法分析
May 02 Python
Python中join函数简单代码示例
Jan 09 Python
python3.6的venv模块使用详解
Aug 01 Python
Python3基础教程之递归函数简单示例
Jun 07 Python
Python字符串对象实现原理详解
Jul 01 Python
django框架使用views.py的函数对表进行增删改查内容操作详解【models.py中表的创建、views.py中函数的使用,基于对象的跨表查询】
Dec 12 Python
python GUI库图形界面开发之PyQt5简单绘图板实例与代码分析
Mar 08 Python
Python实现疫情通定时自动填写功能(附代码)
May 27 Python
PyQt5实现登录页面
May 30 Python
OpenCV利用python来实现图像的直方图均衡化
Oct 21 Python
python unichr函数知识点总结
Dec 16 Python
解决tensorflow模型压缩的问题_踩坑无数,总算搞定
Mar 02 #Python
python Protobuf定义消息类型知识点讲解
Mar 02 #Python
Django项目在pycharm新建的步骤方法
Mar 02 #Python
基于注解实现 SpringBoot 接口防刷的方法
Mar 02 #Python
python Autopep8实现按PEP8风格自动排版Python代码
Mar 02 #Python
pycharm配置安装autopep8自动规范代码的实现
Mar 02 #Python
Python实现我的世界小游戏源代码
Mar 02 #Python
You might like
浅析php学习的路线图
2013/07/10 PHP
php function用法如何递归及return和echo区别
2014/03/07 PHP
PHP进程同步代码实例
2015/02/12 PHP
Laravel 不同生产环境服务器的判断实践
2019/10/15 PHP
动态加载iframe
2006/06/16 Javascript
JavaScript 打地鼠游戏代码说明
2010/10/12 Javascript
javascript从右边截取指定字符串的三种实现方法
2013/11/29 Javascript
基于node实现websocket协议
2016/04/25 Javascript
js实现textarea限制输入字数
2017/02/13 Javascript
详解weex默认webpack.config.js改造
2018/01/08 Javascript
React组件refs的使用详解
2018/02/09 Javascript
ng-repeat指令在迭代对象时的去重方法
2018/10/02 Javascript
基于three.js实现的3D粒子动效实例代码
2019/04/09 Javascript
Easyui 去除jquery-easui tab页div自带滚动条的方法
2019/05/10 jQuery
elementUI vue this.$confirm 和el-dialog 弹出框 移动 示例demo
2019/07/03 Javascript
layui数据表格跨行自动合并的例子
2019/09/02 Javascript
Python实现提取文章摘要的方法
2015/04/21 Python
Python  pip安装lxml出错的问题解决办法
2017/02/10 Python
深入学习Python中的上下文管理器与else块
2017/08/27 Python
python使用pil库实现图片合成实例代码
2018/01/20 Python
Python3.6日志Logging模块简单用法示例
2018/06/14 Python
PyTorch 中的傅里叶卷积实现示例
2020/12/11 Python
Mansur Gavriel官网:纽约市的一个设计品牌
2019/05/02 全球购物
Kipling意大利官网:世界著名的时尚休闲包袋品牌
2019/06/05 全球购物
英国豪华装饰照明品牌的在线零售商:Inspyer Lighting
2019/12/10 全球购物
EJB与JAVA BEAN的区别
2016/08/29 面试题
售后服务经理岗位职责范本
2014/02/22 职场文书
财务会计大学生自我评价
2014/04/09 职场文书
英文请假条
2014/04/11 职场文书
保密工作责任书
2014/04/16 职场文书
班级标语大全
2014/06/21 职场文书
普宁寺导游词
2015/02/04 职场文书
朋友圈早安励志语录!
2019/07/08 职场文书
MySQL系列之八 MySQL服务器变量
2021/07/02 MySQL
Python必备技巧之函数的使用详解
2022/04/04 Python
多台电脑共享文件怎么设置?多台电脑共享文件操作教程
2022/04/08 数码科技