pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python输出当前目录下index.html文件路径的方法
Apr 28 Python
python抓取百度首页的方法
May 19 Python
Python守护进程用法实例分析
Jun 04 Python
python3实现基于用户的协同过滤
May 31 Python
深入浅析Python2.x和3.x版本的主要区别
Nov 30 Python
python实现向微信用户发送每日一句 python实现微信聊天机器人
Mar 27 Python
利用anaconda保证64位和32位的python共存
Mar 09 Python
python join方法使用详解
Jul 30 Python
Python中输入和输出(打印)数据实例方法
Oct 13 Python
win10系统Anaconda和Pycharm的Tensorflow2.0之CPU和GPU版本安装教程
Dec 03 Python
k-means & DBSCAN 总结
Apr 27 Python
Python数据分析之pandas读取数据
Jun 02 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
什么是MVC,好东西啊
2007/05/03 PHP
PHP获取http请求的头信息实现步骤
2012/12/16 PHP
Zend的MVC机制使用分析(二)
2013/05/02 PHP
PHP+MYSQL会员系统的开发实例教程
2014/08/23 PHP
Yii扩展组件编写方法实例分析
2015/06/29 PHP
Netbeans 8.2将支持PHP7 更精彩
2016/06/13 PHP
php求斐波那契数的两种实现方式【递归与递推】
2019/09/09 PHP
用js实现下载远程文件并保存在本地的脚本
2008/05/06 Javascript
js调试系列 初识控制台
2014/06/18 Javascript
javascript操作ul中li的方法
2015/05/14 Javascript
jquery表单验证需要做些什么
2015/11/17 Javascript
4种JavaScript实现简单tab选项卡切换的方法
2016/01/06 Javascript
AngularJS ng-controller 指令简单实例
2016/08/01 Javascript
JavaScript实现DOM对象选择器
2016/09/24 Javascript
分类解析jQuery选择器
2016/11/23 Javascript
解决在vue+webpack开发中出现两个或多个菜单公用一个组件问题
2017/11/28 Javascript
javaScript字符串工具类StringUtils详解
2017/12/08 Javascript
将Sublime Text 3 添加到右键中的简单方法
2017/12/12 Javascript
如何用JavaScript实现功能齐全的单链表详解
2019/02/11 Javascript
javascript 高级语法之继承的基本使用方法示例
2019/11/11 Javascript
vue 通过base64实现图片下载功能
2020/12/19 Vue.js
Python中条件选择和循环语句使用方法介绍
2013/03/13 Python
Python中的进程分支fork和exec详解
2015/04/11 Python
浅析Python的web.py框架中url的设定方法
2016/07/11 Python
浅析Python中yield关键词的作用与用法
2016/11/29 Python
python正则实现计算器功能
2017/12/14 Python
python中for用来遍历range函数的方法
2018/06/08 Python
python读取ini配置的类封装代码实例
2020/01/08 Python
使用Python来做一个屏幕录制工具的操作代码
2020/01/18 Python
解决pyecharts运行后产生的html文件用浏览器打开空白
2020/03/11 Python
Jupyter Notebook 远程访问配置详解
2021/01/11 Python
今天学到的CSS最新技术(与图片背景相关)
2012/12/24 HTML / CSS
html5生成柱状图(条形图)效果的实例代码
2016/03/25 HTML / CSS
DERMAdoctor官网:美国著名皮肤护理品牌
2019/07/06 全球购物
Java中会存在内存泄漏吗,请简单描述
2016/12/22 面试题
优秀团员事迹材料2000字
2014/08/20 职场文书