使用tensorflow实现VGG网络,训练mnist数据集方式


Posted in Python onMay 26, 2020

VGG作为流行的几个模型之一,训练图形数据效果不错,在mnist数据集是常用的入门集数据,VGG层数非常多,如果严格按照规范来实现,并用来训练mnist数据集,会出现各种问题,如,经过16层卷积后,28*28*1的图片几乎无法进行。

先介绍下VGG

ILSVRC 2014的第二名是Karen Simonyan和 Andrew Zisserman实现的卷积神经网络,现在称其为VGGNet。它主要的贡献是展示出网络的深度是算法优良性能的关键部分。

他们最好的网络包含了16个卷积/全连接层。网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的汇聚。他们的预训练模型是可以在网络上获得并在Caffe中使用的。

VGGNet不好的一点是它耗费更多计算资源,并且使用了更多的参数,导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。

模型结构:

使用tensorflow实现VGG网络,训练mnist数据集方式

本文在实现时候,尽量保存VGG原来模型结构,核心代码如下:

weights ={
  'wc1':tf.Variable(tf.random_normal([3,3,1,64])),
  'wc2':tf.Variable(tf.random_normal([3,3,64,64])),
  'wc3':tf.Variable(tf.random_normal([3,3,64,128])),
  'wc4':tf.Variable(tf.random_normal([3,3,128,128])),
  
  'wc5':tf.Variable(tf.random_normal([3,3,128,256])),
  'wc6':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc7':tf.Variable(tf.random_normal([3,3,256,256])),
  'wc8':tf.Variable(tf.random_normal([3,3,256,256])),
  
  'wc9':tf.Variable(tf.random_normal([3,3,256,512])),
  'wc10':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc11':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc12':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc13':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc14':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc15':tf.Variable(tf.random_normal([3,3,512,512])),
  'wc16':tf.Variable(tf.random_normal([3,3,512,256])),
  
  'wd1':tf.Variable(tf.random_normal([4096,4096])),
  'wd2':tf.Variable(tf.random_normal([4096,4096])),
  'out':tf.Variable(tf.random_normal([4096,nn_classes])),
}
 
biases ={
  'bc1':tf.Variable(tf.zeros([64])),
  'bc2':tf.Variable(tf.zeros([64])),
  'bc3':tf.Variable(tf.zeros([128])),
  'bc4':tf.Variable(tf.zeros([128])),
  'bc5':tf.Variable(tf.zeros([256])),
  'bc6':tf.Variable(tf.zeros([256])),
  'bc7':tf.Variable(tf.zeros([256])),
  'bc8':tf.Variable(tf.zeros([256])),
  'bc9':tf.Variable(tf.zeros([512])),
  'bc10':tf.Variable(tf.zeros([512])),
  'bc11':tf.Variable(tf.zeros([512])),
  'bc12':tf.Variable(tf.zeros([512])),
  'bc13':tf.Variable(tf.zeros([512])),
  'bc14':tf.Variable(tf.zeros([512])),
  'bc15':tf.Variable(tf.zeros([512])),
  'bc16':tf.Variable(tf.zeros([256])),
  
  
  'bd1':tf.Variable(tf.zeros([4096])),
  'bd2':tf.Variable(tf.zeros([4096])),
  'out':tf.Variable(tf.zeros([nn_classes])),
}

卷积实现:

def convLevel(i,input,type):
  num = i
  out = conv2D('conv'+str(num),input,weights['wc'+str(num)],biases['bc'+str(num)])
  if type=='p':
    out = maxPool2D('pool'+str(num),out, k=2) 
    out = norm('norm'+str(num),out, lsize=4)
  return out 
 
def VGG(x,weights,biases,dropout):
  x = tf.reshape(x,shape=[-1,28,28,1])
 
  input = x
 
  for i in range(16):
    i += 1
    if(i==2) or (i==4) or (i==12) : # 根据模型定义还需要更多的POOL化,但mnist图片大小不允许。
      input = convLevel(i,input,'p')
    else:
      input = convLevel(i,input,'c')

训练:

pred = VGG(x, weights, biases, keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy_ = tf.reduce_mean(tf.cast(correct_pred,tf.float32))
 
init = tf.global_variables_initializer()
with tf.Session() as sess:
  sess.run(init)
  step = 1
  while step*batch_size < train_iters:
    batch_x,batch_y = mnist.train.next_batch(batch_size)
    sess.run(optimizer,feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})
    print(step*batch_size)
    if step % display_step == 0 :
      #loss,acc = sess.run([cost,accuracy],feed_dict={x:batch_x,y:batch_y,keep_prob=1.0})
      acc = sess.run(accuracy_, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      # 计算损失值
      
      loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y, keep_prob: 1.})
      print("iter: "+str(step*batch_size)+"mini batch Loss="+"{:.6f}".format(loss)+",acc="+"{:6f}".format(acc))
 
    step += 1 
   
  print("training end!")

最终效果:

训练10000次后:结果如下:

iter: 12288 mini batch Loss=5088409.500000,acc=0.578125

iter: 12800 mini batch Loss=4514274.000000,acc=0.601562

iter: 13312 mini batch Loss=4483454.500000,acc=0.648438

这种深度的模型可以考虑循环10万次以上。目前效果还不错,本人没有GPU,心痛笔记本的CPU,100%的CPU利用率,听到风扇响就不忍心再训练,本文也借鉴了alex网络实现,当然我也实现了这个网络模型。在MNIST数据上,ALEX由于层数较少,收敛更快,当然MNIST,用CNN足够了。

以上这篇使用tensorflow实现VGG网络,训练mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python程序语言快速上手教程
Jul 18 Python
给Python初学者的一些编程技巧
Apr 03 Python
Python使用getpass库读取密码的示例
Oct 10 Python
Python Pandas找到缺失值的位置方法
Apr 12 Python
Python基于jieba库进行简单分词及词云功能实现方法
Jun 16 Python
Python字符串逆序的实现方法【一题多解】
Feb 18 Python
11个Python3字典内置方法大全与示例汇总
May 13 Python
Python3操作MongoDB增册改查等方法详解
Feb 10 Python
如何导出python安装的所有模块名称和版本号到文件中
Jun 05 Python
Python爬虫基于lxml解决数据编码乱码问题
Jul 31 Python
python使用多线程查询数据库的实现示例
Aug 17 Python
Python调用高德API实现批量地址转经纬度并写入表格的功能
Jan 12 Python
浅谈Tensorflow加载Vgg预训练模型的几个注意事项
May 26 #Python
Tensorflow加载Vgg预训练模型操作
May 26 #Python
PyQt5如何将.ui文件转换为.py文件的实例代码
May 26 #Python
TensorFlow实现模型断点训练,checkpoint模型载入方式
May 26 #Python
python 日志模块 日志等级设置失效的解决方案
May 26 #Python
python3.7+selenium模拟淘宝登录功能的实现
May 26 #Python
TensorFlow固化模型的实现操作
May 26 #Python
You might like
利用discuz自带通行证整合dedecms的方法以及文件下载
2007/03/06 PHP
yii框架表单模型使用及以数组形式提交表单数据示例
2014/04/30 PHP
PHP获取windows登录用户名的方法
2014/06/24 PHP
支持png透明图片的php生成缩略图类分享
2015/02/08 PHP
javascript 放大镜 v1.0 基于Yui2 实现的放大镜效果
2010/03/08 Javascript
jquery判断元素的子元素是否存在的示例代码
2014/02/04 Javascript
javascript的数组和常用函数详解
2014/05/09 Javascript
jQuery实现页面内锚点平滑跳转特效的方法总结
2015/05/11 Javascript
javascript如何操作HTML下拉列表标签
2015/08/20 Javascript
基于OL2实现百度地图ABCD marker的效果
2015/10/01 Javascript
js获取Html元素的实际宽度高度的方法
2016/05/19 Javascript
jQuery搜索框效果实现代码(百度关键词联想)
2021/02/25 Javascript
canvas绘制多边形
2017/02/24 Javascript
JavaScript中transform实现数字翻页效果
2017/03/08 Javascript
JavaScript仿微信打飞机游戏
2020/07/05 Javascript
JS检测是否可以访问公网服务器功能代码
2017/06/19 Javascript
BootStrap Table复选框默认选中功能的实现代码(从数据库获取到对应的状态进行判断是否为选中状态)
2017/07/11 Javascript
JavaScript:ES2019 的新特性(译)
2019/08/08 Javascript
解决layer 关闭当前弹窗 关闭遮罩层 input值获取不到的问题
2019/09/25 Javascript
Vue使用Proxy代理后仍无法生效的解决
2020/11/13 Javascript
[51:10]VP vs VGJ.S 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
Python排序算法之选择排序定义与用法示例
2018/04/29 Python
Python3.5 Pandas模块缺失值处理和层次索引实例详解
2019/04/23 Python
python函数的万能参数传参详解
2019/07/26 Python
使用Python完成15位18位身份证的互转功能
2019/11/06 Python
Python 实现训练集、测试集随机划分
2020/01/08 Python
Python3批量创建Crowd用户并分配组
2020/05/20 Python
css3制作彩色边线3d立体按钮的示例(css3按钮)
2014/05/06 HTML / CSS
使用HTML5的链接预取功能(link prefetching)给网站提速
2012/12/13 HTML / CSS
英国排名第一的在线宠物用品商店:Monster Pet Supplies
2018/05/20 全球购物
英国皇家造币厂:The Royal Mint
2018/10/05 全球购物
英国领先的露营和露营车品牌之一:OLPRO
2019/08/06 全球购物
大学生村官心得体会范文
2014/01/04 职场文书
团员个人年度总结
2015/02/26 职场文书
法人身份证明书
2015/06/18 职场文书
《少年闰土》教学反思
2016/02/18 职场文书