pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的闭包实例详解
Aug 29 Python
python使用正则表达式替换匹配成功的组并输出替换的次数
Nov 22 Python
同时安装Python2 & Python3 cmd下版本自由选择的方法
Dec 09 Python
python批量设置多个Excel文件页眉页脚的脚本
Mar 14 Python
python实现类之间的方法互相调用
Apr 29 Python
PyTorch的深度学习入门之PyTorch安装和配置
Jun 27 Python
解决Python3 控制台输出InsecureRequestWarning问题
Jul 15 Python
python3 深浅copy对比详解
Aug 12 Python
Laravel框架表单验证格式化输出的方法
Sep 25 Python
python 实现dict转json并保存文件
Dec 05 Python
python实现指定ip端口扫描方式
Dec 17 Python
从np.random.normal()到正态分布的拟合操作
Jun 02 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
php while循环得到循环次数
2013/10/26 PHP
从零开始学YII2框架(六)高级应用程序模板
2014/08/20 PHP
PHP中strnatcmp()函数“自然排序算法”进行字符串比较用法分析(对比strcmp函数)
2016/01/07 PHP
使用ltrace工具跟踪PHP库函数调用的方法
2016/04/25 PHP
php实现文件上传基本验证
2020/03/04 PHP
jquery创建表格(自动增加表格)代码分享
2013/12/25 Javascript
轻松创建nodejs服务器(9):实现非阻塞操作
2014/12/18 NodeJs
javascript与jquery中的this关键字用法实例分析
2015/12/24 Javascript
浅述Javascript的外部对象
2016/12/07 Javascript
JavaScript实现二分查找实例代码
2017/02/22 Javascript
JS与jQuery实现ListBox上移,下移,左移,右移操作功能示例
2018/05/31 jQuery
详解Vue iview IE浏览器不兼容报错(Iview Bable polyfill)
2019/01/07 Javascript
Seajs源码详解分析
2019/04/02 Javascript
JS实现扫码枪扫描二维码功能
2020/01/03 Javascript
Angular 多模块项目构建过程
2020/02/13 Javascript
JS中类的静态方法,静态变量,实例方法,实例变量区别与用法实例分析
2020/03/14 Javascript
在Python的循环体中使用else语句的方法
2015/03/30 Python
python使用calendar输出指定年份全年日历的方法
2015/04/04 Python
Python3调用微信企业号API发送文本消息代码示例
2017/11/10 Python
Python3实现购物车功能
2018/04/18 Python
对Python+opencv将图片生成视频的实例详解
2019/01/08 Python
python的turtle库使用详解
2019/05/10 Python
Python递归及尾递归优化操作实例分析
2020/02/01 Python
Selenium Webdriver元素定位的八种常用方式(小结)
2021/01/13 Python
新西兰床上用品和家居用品购物网站:Adairs
2018/04/27 全球购物
美国战术品牌:5.11 Tactical
2019/05/01 全球购物
如何开启linux的ssh服务
2015/02/14 面试题
竞聘上岗演讲稿范文
2014/01/10 职场文书
校园十大歌手策划书
2014/02/01 职场文书
英语教师自荐信
2014/05/26 职场文书
教师个人培训总结
2015/02/11 职场文书
2015年三年级班主任工作总结
2015/05/21 职场文书
环保主题班会教案
2015/08/13 职场文书
golang特有程序结构入门教程
2021/06/02 Python
php实例化对象的实例方法
2021/11/17 PHP
Python+OpenCV实现图片中的圆形检测
2022/04/07 Python