python机器学习之神经网络实现


Posted in Python onOctober 13, 2018

神经网络在机器学习中有很大的应用,甚至涉及到方方面面。本文主要是简单介绍一下神经网络的基本理论概念和推算。同时也会介绍一下神经网络在数据分类方面的应用。

首先,当我们建立一个回归和分类模型的时候,无论是用最小二乘法(OLS)还是最大似然值(MLE)都用来使得残差达到最小。因此我们在建立模型的时候,都会有一个loss function。

而在神经网络里也不例外,也有个类似的loss function。

对回归而言:

python机器学习之神经网络实现

对分类而言:

python机器学习之神经网络实现

然后同样方法,对于W开始求导,求导为零就可以求出极值来。

关于式子中的W。我们在这里以三层的神经网络为例。先介绍一下神经网络的相关参数。

python机器学习之神经网络实现

第一层是输入层,第二层是隐藏层,第三层是输出层。

在X1,X2经过W1的加权后,达到隐藏层,然后经过W2的加权,到达输出层

其中,

python机器学习之神经网络实现

我们有:

python机器学习之神经网络实现

至此,我们建立了一个初级的三层神经网络。

当我们要求其的loss function最小时,我们需要逆向来求,也就是所谓的backpropagation。

我们要分别对W1和W2进行求导,然后求出其极值。

从右手边开始逆推,首先对W2进行求导。

代入损失函数公式:

python机器学习之神经网络实现

python机器学习之神经网络实现

然后,我们进行化简:

python机器学习之神经网络实现

化简到这里,我们同理再对W1进行求导。

python机器学习之神经网络实现

我们可以发现当我们在做bp网络时候,有一个逆推回去的误差项,其决定了loss function 的最终大小。

在实际的运算当中,我们会用到梯度求解,来求出极值点。

python机器学习之神经网络实现

总结一下来说,我们使用向前推进来理顺神经网络做到回归分类等模型。而向后推进来计算他的损失函数,使得参数W有一个最优解。

当然,和线性回归等模型相类似的是,我们也可以加上正则化的项来对W参数进行约束,以免使得模型的偏差太小,而导致在测试集的表现不佳。

python机器学习之神经网络实现

python机器学习之神经网络实现

Python 的实现:

使用了KERAS的库

解决线性回归: 

model.add(Dense(1, input_dim=n_features, activation='linear', use_bias=True))

# Use mean squared error for the loss metric and use the ADAM backprop algorithm
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the network (learn the weights)
# We need to convert from DataFrame to NumpyArray
history = model.fit(X_train.values, y_train.values, epochs=100, 
     batch_size=1, verbose=2, validation_split=0)

解决多重分类问题: 

# create model
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=n_features))
model.add(Dropout(0.5))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
# Softmax output layer
model.add(Dense(7, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(X_train.values, y_train.values, epochs=20, batch_size=16)

y_pred = model.predict(X_test.values)

y_te = np.argmax(y_test.values, axis = 1)
y_pr = np.argmax(y_pred, axis = 1)

print(np.unique(y_pr))

print(classification_report(y_te, y_pr))

print(confusion_matrix(y_te, y_pr))

当我们选取最优参数时候,有很多种解决的途径。这里就介绍一种是gridsearchcv的方法,这是一种暴力检索的方法,遍历所有的设定参数来求得最优参数。

from sklearn.model_selection import GridSearchCV

def create_model(optimizer='rmsprop'):
 model = Sequential()
 model.add(Dense(64, activation='relu', input_dim=n_features))
 model.add(Dropout(0.5))
 model.add(Dense(64, activation='relu'))
 model.add(Dropout(0.5))
 model.add(Dense(7, activation='softmax'))
 model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
 
 return model

model = KerasClassifier(build_fn=create_model, verbose=0)

optimizers = ['rmsprop']
epochs = [5, 10, 15]
batches = [128]


param_grid = dict(optimizer=optimizers, epochs=epochs, batch_size=batches, verbose=['2'])
grid = GridSearchCV(estimator=model, param_grid=param_grid)

grid.fit(X_train.values, y_train.values)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python转换HTML到Text纯文本的方法
Jan 15 Python
Python实现删除Android工程中的冗余字符串
Jan 19 Python
Python获取DLL和EXE文件版本号的方法
Mar 10 Python
python实现自动解数独小程序
Jan 21 Python
Python numpy中矩阵的基本用法汇总
Feb 12 Python
Python 限制线程的最大数量的方法(Semaphore)
Feb 22 Python
Python进阶之@property动态属性的实现
Apr 01 Python
Python中那些 Pythonic的写法详解
Jul 02 Python
基于TensorFlow中自定义梯度的2种方式
Feb 04 Python
在pycharm中使用matplotlib.pyplot 绘图时报错的解决
Jun 01 Python
python实现凯撒密码、凯撒加解密算法
Jun 11 Python
安装python依赖包psycopg2来调用postgresql的操作
Jan 01 Python
Python pyinotify模块实现对文档的实时监控功能方法
Oct 13 #Python
基于pycharm导入模块显示不存在的解决方法
Oct 13 #Python
解决PyCharm import torch包失败的问题
Oct 13 #Python
python3+requests接口自动化session操作方法
Oct 13 #Python
解决pycharm无法识别本地site-packages的问题
Oct 13 #Python
解决PyCharm同目录下导入模块会报错的问题
Oct 13 #Python
python中单例常用的几种实现方法总结
Oct 13 #Python
You might like
德生PL660的电路分析和打磨
2021/03/02 无线电
一步一步学习PHP(3) php 函数
2010/02/15 PHP
让PHP开发者事半功倍的十大技巧小结
2010/04/20 PHP
windows的文件系统机制引发的PHP路径爆破问题分析
2014/07/28 PHP
php中socket通信机制实例详解
2015/01/03 PHP
CI框架中site_url()和base_url()的区别
2015/01/07 PHP
PHP也能干大事 随机函数
2015/04/14 PHP
php简单解析mysqli查询结果的方法(2种方法)
2016/06/29 PHP
laravel5.1框架基础之Blade模板继承简单使用方法分析
2019/09/05 PHP
JavaScript中的Location地址对象
2008/01/16 Javascript
jQuery中DOM树操作之复制元素的方法
2015/01/23 Javascript
Javascript获取统一管理的提示语(message)
2016/02/03 Javascript
JavaScript实现简单的星星评分效果
2017/05/18 Javascript
利用yarn代替npm管理前端项目模块依赖的方法详解
2017/09/04 Javascript
jQuery中库的引用方法
2018/01/06 jQuery
vue watch普通监听和深度监听实例详解(数组和对象)
2018/08/16 Javascript
使用 JavaScript 创建并下载文件(模拟点击)
2019/10/25 Javascript
Element Carousel 走马灯的具体实现
2020/07/26 Javascript
Python常用的日期时间处理方法示例
2015/02/08 Python
Python第三方库xlrd/xlwt的安装与读写Excel表格
2017/01/21 Python
Python实现PS图像调整之对比度调整功能示例
2018/01/26 Python
Python面向对象之继承代码详解
2018/01/29 Python
PyCharm设置SSH远程调试的方法
2018/07/17 Python
Python实现银行账户资金交易管理系统
2020/01/03 Python
Python通过4种方式实现进程数据通信
2020/03/12 Python
MoviePy常用剪辑类及Python视频剪辑自动化
2020/12/18 Python
iframe跨域的几种常用方法
2019/11/11 HTML / CSS
如何保障Web服务器安全
2014/05/05 面试题
经典婚礼主持开场白
2014/03/13 职场文书
幼儿园运动会口号
2014/06/07 职场文书
语文教育专业求职信
2014/06/28 职场文书
放飞梦想演讲稿600字
2014/08/26 职场文书
财政局长个人总结
2015/03/04 职场文书
地震捐款简报
2015/07/21 职场文书
网吧员工管理制度
2015/08/05 职场文书
Nginx解决403 forbidden的完整步骤
2021/04/01 Servers