使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python面向对象特殊成员
Apr 24 Python
pandas中Timestamp类用法详解
Dec 11 Python
Python中使用支持向量机SVM实践
Dec 27 Python
Python实现的三层BP神经网络算法示例
Feb 07 Python
python xlsxwriter库生成图表的应用示例
Mar 16 Python
tensorflow 中对数组元素的操作方法
Jul 27 Python
python的pygal模块绘制反正切函数图像方法
Jul 16 Python
Pytorch 计算误判率,计算准确率,计算召回率的例子
Jan 18 Python
Python爬虫库BeautifulSoup获取对象(标签)名,属性,内容,注释
Jan 25 Python
python 使用建议与技巧分享(四)
Aug 18 Python
python爬虫中PhantomJS加载页面的实例方法
Nov 12 Python
python中Pyqt5使用Qlabel标签播放视频
Apr 22 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
php文件上传表单摘自drupal的代码
2011/02/15 PHP
ThinkPHP中Widget扩展的两种写法及调用方法详解
2017/05/04 PHP
php验证码生成器
2017/05/24 PHP
Laravel 5.5基于内置的Auth模块实现前后台登陆详解
2017/12/21 PHP
PHP PDOStatement::bindParam讲解
2019/01/30 PHP
设定php简写功能的方法
2019/11/28 PHP
PHP Beanstalkd消息队列的安装与使用方法实例详解
2020/02/21 PHP
Jquery Select操作方法集合脚本之家特别版
2010/05/17 Javascript
Fastest way to build an HTML string(拼装html字符串的最快方法)
2011/08/20 Javascript
js 距离某一时间点时间是多少实现代码
2013/10/14 Javascript
JavaScript截取指定长度字符串点击可以展开全部代码
2015/12/04 Javascript
JS中跨页面调用变量和函数的方法(例如a.js 和 b.js中互相调用)
2016/11/01 Javascript
js原生实现FastClick事件的实例
2016/11/20 Javascript
JavaScript满天星导航栏实现方法
2018/03/08 Javascript
如何把vuejs打包出来的文件整合到springboot里
2018/07/26 Javascript
ant design vue嵌套表格及表格内部编辑的用法说明
2020/10/28 Javascript
基于vue项目设置resolves.alias: '@'路径并适配webstorm
2020/12/02 Vue.js
vue中watch的用法汇总
2020/12/28 Vue.js
[01:01:23]完美世界DOTA2联赛PWL S2 Forest vs FTD.C 第一场 11.26
2020/11/30 DOTA
python实现划词翻译
2020/04/23 Python
Python中的自定义函数学习笔记
2014/09/23 Python
使用NumPy和pandas对CSV文件进行写操作的实例
2018/06/14 Python
Python线程threading模块用法详解
2020/02/26 Python
python中字符串的编码与解码详析
2020/12/03 Python
Python爬虫+Tkinter制作一个翻译软件的示例
2021/02/20 Python
利用Bootstrap实现漂亮简洁的CSS3价格表实例源码
2017/03/02 HTML / CSS
HTML 5 input placeholder 属性如何完美兼任ie
2014/05/12 HTML / CSS
英国领先的互联网葡萄酒礼品商:Vintage Wine & Port
2019/05/24 全球购物
什么是smarty? Smarty的优点是什么?
2013/08/11 面试题
如何使用PHP session
2015/04/21 面试题
说说你所熟悉或听说过的j2ee中的几种常用模式?及对设计模式的一些看法
2012/05/24 面试题
建筑工程专业学生的自我评价
2013/12/25 职场文书
大学生2014全国两会学习心得体会
2014/03/10 职场文书
初三学习计划书范文
2014/04/30 职场文书
毕业论文答辩稿范文
2015/06/23 职场文书
win11无法添加打印机怎么办? 提示windows无法打开添加打印机的解决办法
2022/04/05 数码科技