Implementing a Residual Network in TensorFlow (MNIST Dataset)


Posted in Python on May 26, 2020

Introduction

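The residual network (ResNet) is the celebrated work of Kaiming He. It performs very well and can reach depths of 1000 layers. Implementing it, however, is not that hard. Here a residual network is implemented on the MNIST dataset with TensorFlow as the framework, though only a fairly shallow one.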

As shown in the figure below:

[Figure: Implementing a residual network in TensorFlow (MNIST dataset)]

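The solid-line connections indicate matching channel counts. For example, the first and third pink rectangles in the figure above are both 3x3x64 feature maps. Because the channels match, the output is computed as H(x)=F(x)+x.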

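The dashed-line connections indicate differing channel counts. For example, the first and third green rectangles in the figure above are 3x3x64 and 3x3x128 feature maps respectively. Because the channels differ, the output is computed as H(x)=F(x)+Wx, where W is a convolution that adjusts the dimensions of x.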
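The two shortcut formulas can be illustrated at a single pixel, where a 1x1 convolution reduces to a matrix-vector product over the channel dimension. The following is only a minimal sketch in plain Python: the residual branches `keep_channels` and `double_channels` and the projection matrix `W` are made-up values for illustration and are not taken from the network in this article.

```python
# Toy sketch of the two shortcut types at a single pixel.
# A 1x1 convolution at one spatial position is a matrix-vector product
# over the channel dimension, so plain lists are enough to illustrate it.

def matvec(W, x):
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def identity_shortcut(x, F):
    """H(x) = F(x) + x; requires F(x) and x to have the same length."""
    fx = F(x)
    assert len(fx) == len(x), "channel counts must match"
    return [a + b for a, b in zip(fx, x)]

def projection_shortcut(x, F, W):
    """H(x) = F(x) + W x; W maps x to the channel count of F(x)."""
    fx = F(x)
    wx = matvec(W, x)
    assert len(fx) == len(wx), "W must project x to the size of F(x)"
    return [a + b for a, b in zip(fx, wx)]

# Hypothetical residual branches, chosen only so the numbers are easy to follow.
def keep_channels(x):      # 2 channels in, 2 channels out
    return [2 * v for v in x]

def double_channels(x):    # 2 channels in, 4 channels out
    return [v for v in x] + [-v for v in x]

x = [1.0, 3.0]
W = [[1, 0], [0, 1], [1, 1], [0, 0]]   # hypothetical 4x2 projection

h_same = identity_shortcut(x, keep_channels)
h_proj = projection_shortcut(x, double_channels, W)
print(h_same)
print(h_proj)
```

In the identity case the shortcut adds x unchanged, while in the projection case x first passes through W so that both operands of the addition have the same number of channels.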

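Depending on whether the input and output dimensions match, blocks are divided into identity_block and conv_block. Each kind of block comes in the two variants shown in the figure above, one with three convolutions and one with two. The three-convolution variant is somewhat faster, so it is the one used here.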
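One rough way to see why the three-convolution variant can be cheaper is to count convolution weights. The sketch below assumes both variants operate on a 256-channel input and that the three-convolution variant squeezes to 64 channels in the middle (the [64, 64, 256] setting used in the code later in this article). Biases are ignored, and weight count is only a proxy for speed, not a measurement.

```python
def conv_params(kernel, c_in, c_out):
    """Number of weights in a kernel x kernel convolution (biases ignored)."""
    return kernel * kernel * c_in * c_out

def two_conv_block(channels):
    """Two 3x3 convolutions that keep the channel count."""
    return 2 * conv_params(3, channels, channels)

def three_conv_block(channels, bottleneck):
    """1x1 reduce, 3x3 at the reduced width, 1x1 restore."""
    return (conv_params(1, channels, bottleneck)
            + conv_params(3, bottleneck, bottleneck)
            + conv_params(1, bottleneck, channels))

two = two_conv_block(256)
three = three_conv_block(256, 64)
print("two-convolution block:  ", two)
print("three-convolution block:", three)
print("ratio: %.1f" % (two / three))
```

Under these assumptions the three-convolution block needs far fewer weights, because the expensive 3x3 convolution runs on 64 channels instead of 256.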

The implementation is given in the code below:

# Residual network on the MNIST dataset in TensorFlow; can be run directly
# (uses the TensorFlow 1.x API)
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# Load MNIST with one-hot labels
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

#x=mnist.train.images
#y=mnist.train.labels
#X=mnist.test.images
#Y=mnist.test.labels
x = tf.placeholder(tf.float32, [None,784])
y = tf.placeholder(tf.float32, [None, 10])
sess = tf.InteractiveSession()

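def weight_variable(shape):
    # Build the initial values from a truncated normal distribution
    initial = tf.truncated_normal(shape, mean=0, stddev=0.1)
    # Create the variable
    return tf.Variable(initial)

def bias_variable(shape):
    # Biases start at a small positive constant
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)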

# Identity block of the residual network; input and output dimensions are the same
def identity_block(X_input, kernel_size, in_filter, out_filters, stage, block):
    """
    Implementation of the identity block.

    Arguments:
    X_input -- input tensor of shape (m, n_H_prev, n_W_prev, n_C_prev)
    kernel_size -- integer, specifying the shape of the middle CONV's window for the main path
    in_filter -- integer, number of channels of X_input; must equal out_filters[2] so the shortcut can be added
    out_filters -- python list of integers, defining the number of filters in the CONV layers of the main path
    stage -- integer, used to name the layers, depending on their position in the network
    block -- string/character, used to name the layers, depending on their position in the network

    Returns:
    add_result -- output of the identity block, tensor of shape (m, n_H, n_W, n_C)
    """

    # defining name basis
    block_name = 'res' + str(stage) + block
    f1, f2, f3 = out_filters
    with tf.variable_scope(block_name):
        X_shortcut = X_input

        # first: 1x1 convolution
        W_conv1 = weight_variable([1, 1, in_filter, f1])
        X = tf.nn.conv2d(X_input, W_conv1, strides=[1, 1, 1, 1], padding='SAME')
        b_conv1 = bias_variable([f1])
        X = tf.nn.relu(X + b_conv1)

        # second: kernel_size x kernel_size convolution
        W_conv2 = weight_variable([kernel_size, kernel_size, f1, f2])
        X = tf.nn.conv2d(X, W_conv2, strides=[1, 1, 1, 1], padding='SAME')
        b_conv2 = bias_variable([f2])
        X = tf.nn.relu(X + b_conv2)

        # third: 1x1 convolution
        W_conv3 = weight_variable([1, 1, f2, f3])
        X = tf.nn.conv2d(X, W_conv3, strides=[1, 1, 1, 1], padding='SAME')
        b_conv3 = bias_variable([f3])
        X = tf.nn.relu(X + b_conv3)

        # final step: add the unchanged shortcut, H(x) = F(x) + x
        add = tf.add(X, X_shortcut)
        b_conv_fin = bias_variable([f3])
        add_result = tf.nn.relu(add + b_conv_fin)

    return add_result


# conv_block: the input and output dimensions differ here, so the shortcut
# needs a convolution to change its dimensions before the two paths can be added
def convolutional_block(X_input, kernel_size, in_filter,
                        out_filters, stage, block, stride=2):
    """
    Implementation of the convolutional block.

    Arguments:
    X_input -- input tensor of shape (m, n_H_prev, n_W_prev, n_C_prev)
    kernel_size -- integer, specifying the shape of the middle CONV's window for the main path
    in_filter -- integer, number of channels of X_input
    out_filters -- python list of integers, defining the number of filters in the CONV layers of the main path
    stage -- integer, used to name the layers, depending on their position in the network
    block -- string/character, used to name the layers, depending on their position in the network
    stride -- Integer, specifying the stride to be used

    Returns:
    add_result -- output of the convolutional block, tensor of shape (m, n_H, n_W, n_C)
    """

    # defining name basis
    block_name = 'res' + str(stage) + block
    with tf.variable_scope(block_name):
        f1, f2, f3 = out_filters

        x_shortcut = X_input
        # first: 1x1 convolution, applies the stride
        W_conv1 = weight_variable([1, 1, in_filter, f1])
        X = tf.nn.conv2d(X_input, W_conv1, strides=[1, stride, stride, 1], padding='SAME')
        b_conv1 = bias_variable([f1])
        X = tf.nn.relu(X + b_conv1)

        # second: kernel_size x kernel_size convolution
        W_conv2 = weight_variable([kernel_size, kernel_size, f1, f2])
        X = tf.nn.conv2d(X, W_conv2, strides=[1, 1, 1, 1], padding='SAME')
        b_conv2 = bias_variable([f2])
        X = tf.nn.relu(X + b_conv2)

        # third: 1x1 convolution
        W_conv3 = weight_variable([1, 1, f2, f3])
        X = tf.nn.conv2d(X, W_conv3, strides=[1, 1, 1, 1], padding='SAME')
        b_conv3 = bias_variable([f3])
        X = tf.nn.relu(X + b_conv3)

        # shortcut path: 1x1 convolution W so that H(x) = F(x) + Wx
        W_shortcut = weight_variable([1, 1, in_filter, f3])
        x_shortcut = tf.nn.conv2d(x_shortcut, W_shortcut, strides=[1, stride, stride, 1], padding='VALID')

        # final
        add = tf.add(x_shortcut, X)
        # bias applied after the two paths are merged
        b_conv_fin = bias_variable([f3])
        add_result = tf.nn.relu(add + b_conv_fin)

    return add_result



# Reshape the flat input into 28x28x1 images. The network tensor is named net
# so that the placeholder x is not overwritten and can still be fed in feed_dict
net = tf.reshape(x, [-1, 28, 28, 1])
w_conv1 = weight_variable([2, 2, 1, 64])
net = tf.nn.conv2d(net, w_conv1, strides=[1, 2, 2, 1], padding='SAME')
b_conv1 = bias_variable([64])
net = tf.nn.relu(net + b_conv1)
# The tensor is now 14x14x64
net = tf.nn.max_pool(net, ksize=[1, 3, 3, 1],
                     strides=[1, 1, 1, 1], padding='SAME')


# stage 2
net = convolutional_block(X_input=net, kernel_size=3, in_filter=64, out_filters=[64, 64, 256], stage=2, block='a', stride=1)
# After the conv_block above, the tensor is 14x14x256
net = identity_block(net, 3, 256, [64, 64, 256], stage=2, block='b')
net = identity_block(net, 3, 256, [64, 64, 256], stage=2, block='c')
# After the blocks above, the tensor is still 14x14x256
net = tf.nn.max_pool(net, [1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
# Now 7x7x256
flat = tf.reshape(net, [-1, 7 * 7 * 256])

w_fc1 = weight_variable([7 * 7 *256, 1024])
b_fc1 = bias_variable([1024])

h_fc1 = tf.nn.relu(tf.matmul(flat, w_fc1) + b_fc1)
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
w_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2


# Build the loss function; cross-entropy is used here
cross_entropy = tf.reduce_mean(
 tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_conv))

train_step = tf.train.AdamOptimizer(1e-3).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Initialize the variables

sess.run(tf.global_variables_initializer())

print("cuiwei")
for i in range(2000):
    batch = mnist.train.next_batch(10)
    # Report the training accuracy every 100 steps (dropout disabled)
    if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={
            x: batch[0], y: batch[1], keep_prob: 1.0})
        print("step %d, training accuracy %g" % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y: batch[1], keep_prob: 0.5})
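The size comments in the code (14x14x64, 14x14x256, 7x7x256) can be checked by hand. With padding='SAME', TensorFlow produces a spatial size of ceil(input / stride) regardless of the kernel size. The sketch below traces that arithmetic for the layers above. The `layers` list is a hand-written summary of the network, not something read from the graph, and the layer names in it are only labels.

```python
import math

def same_out(size, stride):
    """Spatial size after a SAME-padded conv or pool: ceil(size / stride)."""
    return math.ceil(size / stride)

# (label, stride, output channels), written by hand from the code above
layers = [
    ("conv 2x2, stride 2", 2, 64),
    ("max_pool 3x3, stride 1", 1, 64),
    ("conv_block, stride 1", 1, 256),
    ("identity_block b", 1, 256),
    ("identity_block c", 1, 256),
    ("max_pool 2x2, stride 2", 2, 256),
]

size, channels = 28, 1   # MNIST images are 28x28 with one channel
trace = []
for name, stride, out_channels in layers:
    size = same_out(size, stride)
    channels = out_channels
    trace.append((name, size, size, channels))
    print("%-24s -> %dx%dx%d" % (name, size, size, channels))

# Number of inputs to the first fully connected layer
flat_units = size * size * channels
print("flattened:", flat_units)
```

The final value matches the 7 * 7 * 256 used for the reshape and for the first fully connected weight matrix.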

