pytorch实现CNN卷积神经网络


Posted in Python onFebruary 19, 2020

本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下

我对卷积神经网络的一些认识

    卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。

    对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True

从网上下载数据集:

```python
train_data = torchvision.datasets.MNIST(
 root="./mnist/",
 train = True,
 transform=torchvision.transforms.ToTensor(),
 download = DOWNLOAD_MNIST,
)

print(train_data.train_data.size())
print(train_data.train_labels.size())

```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.

test_y = test_data.test_labels[:2000]


class CNN(nn.Module):
 def __init__(self):
  super(CNN, self).__init__()
  self.conv1 = nn.Sequential(
   nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=5,
    stride=1,
    padding=2,
   ),
   
   nn.ReLU(),
   nn.MaxPool2d(kernel_size=2),
  )
  
  self.conv2 = nn.Sequential(
   nn.Conv2d(16, 32, 5, 1, 2),
   nn.ReLU(),
   nn.MaxPool2d(2),
  )
  
  self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
  
 def forward(self, x):
  x = self.conv1(x)
  x = self.conv2(x)
  x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
  output = self.out(x)
  return output
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
 
 from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
 plt.cla()
 X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
 for x, y, s in zip(X, Y, labels):
  c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
 plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)

plt.ion()

for epoch in range(EPOCH):
 for step, (b_x, b_y) in enumerate(train_loader):
  output = cnn(b_x)
  loss = loss_func(output, b_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  
  if step % 50 == 0:
   test_output = cnn(test_x)
   pred_y = torch.max(test_output, 1)[1].data.numpy()
   accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
   print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(), 
     "| test accuracy: %.2f" % accuracy)
   
plt.ioff()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python内置的字符串处理函数整理
Jan 29 Python
Django中cookie的基本使用方法示例
Feb 03 Python
Python中的上下文管理器和with语句的使用
Apr 17 Python
Python使用分布式锁的代码演示示例
Jul 30 Python
python实现指定ip端口扫描方式
Dec 17 Python
python爬虫爬取监控教务系统的思路详解
Jan 08 Python
python3 Scrapy爬虫框架ip代理配置的方法
Jan 17 Python
python对execl 处理操作代码
Jun 22 Python
python使用布隆过滤器的实现示例
Aug 20 Python
利用python清除移动硬盘中的临时文件
Oct 28 Python
Python使用scapy模块发包收包
May 07 Python
Python可变集合和不可变集合的构造方法大全
Dec 06 Python
python+opencv3生成一个自定义纯色图教程
Feb 19 #Python
Python 实现Image和Ndarray互相转换
Feb 19 #Python
python3+opencv生成不规则黑白mask实例
Feb 19 #Python
使用celery和Django处理异步任务的流程分析
Feb 19 #Python
Python Numpy,mask图像的生成详解
Feb 19 #Python
浅谈图像处理中掩膜(mask)的意义
Feb 19 #Python
Python中logging日志库实例详解
Feb 19 #Python
You might like
PHP技术开发技巧分享
2010/03/23 PHP
thinkPHP实现将excel导入到数据库中的方法
2016/04/22 PHP
PHP实现的装箱算法示例
2018/06/23 PHP
侧栏跟随滚动的简单实现代码
2013/03/18 Javascript
博客侧边栏模块跟随滚动条滑动固定效果的实现方法(js+jquery等)
2013/03/24 Javascript
jquery实现页面图片等比例放大缩小功能
2014/02/12 Javascript
Jquery使用val方法读写value值
2015/05/18 Javascript
javascript实现行拖动的方法
2015/05/27 Javascript
JS实现无限级网页折叠菜单(类似树形菜单)效果代码
2015/09/17 Javascript
浅谈js构造函数的方法与原型prototype
2016/07/04 Javascript
angularJS开发注意事项
2018/05/26 Javascript
浅析js中mvvm模式实现的原理
2018/10/06 Javascript
动态创建类实例代码
2009/10/07 Python
连接Python程序与MySQL的教程
2015/04/29 Python
pyQt4实现俄罗斯方块游戏
2018/06/26 Python
python3+selenium实现qq邮箱登陆并发送邮件功能
2019/01/23 Python
Python面向对象程序设计中类的定义、实例化、封装及私有变量/方法详解
2019/02/28 Python
Python3.7安装keras和TensorFlow的教程图解
2020/06/18 Python
python 队列基本定义与使用方法【初始化、赋值、判断等】
2019/10/24 Python
Python 支持向量机分类器的实现
2020/01/15 Python
python:批量统计xml中各类目标的数量案例
2020/03/10 Python
html5唤醒APP小记
2019/03/27 HTML / CSS
美国著名童装品牌:OshKosh B’gosh
2016/08/05 全球购物
Perfumetrader荷兰:香水、化妆品和护肤品在线商店
2017/09/15 全球购物
递归计算如下递归函数的值(斐波拉契)
2012/02/04 面试题
27个经典Linux面试题及答案,你知道几个?
2014/03/11 面试题
区域销售经理职责
2013/12/22 职场文书
餐厅总经理岗位职责
2013/12/31 职场文书
中医专业职业生涯规划书范文
2014/01/04 职场文书
大学毕业生求职自荐书
2014/06/05 职场文书
公司合作意向书范文
2014/07/30 职场文书
2014年学校法制宣传日活动总结
2014/11/01 职场文书
解除同居协议书
2015/01/29 职场文书
小学大队干部竞选稿
2015/11/20 职场文书
创业计划书之养殖业
2019/10/11 职场文书
Django实现在线无水印抖音视频下载(附源码及地址)
2021/05/06 Python