pytorch + visdom CNN处理自建图片数据集的方法


Posted in Python onJune 04, 2018

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

pytorch + visdom CNN处理自建图片数据集的方法 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)

{'train': 244, 'val': 153} ['ants', 'bees']

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))

out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)

pytorch + visdom CNN处理自建图片数据集的方法

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418

pytorch + visdom CNN处理自建图片数据集的方法

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098

pytorch + visdom CNN处理自建图片数据集的方法

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248

10个epoch直接的到95%的准确率。

pytorch + visdom CNN处理自建图片数据集的方法

示例代码:https://github.com/ffzs/ml_pytorch/blob/master/ml_pytorch_hymenoptera

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python re模块介绍
Nov 30 Python
Python 的内置字符串方法小结
Mar 15 Python
python脚本作为Windows服务启动代码详解
Feb 11 Python
Python3中bytes类型转换为str类型
Sep 27 Python
Python3爬虫学习之将爬取的信息保存到本地的方法详解
Dec 12 Python
python主线程与子线程的结束顺序实例解析
Dec 17 Python
Python3中configparser模块读写ini文件并解析配置的用法详解
Feb 18 Python
python实现3D地图可视化
Mar 25 Python
python列表的逆序遍历实现
Apr 20 Python
python 图像判断,清晰度(明暗),彩色与黑白实例
Jun 04 Python
浅谈pymysql查询语句中带有in时传递参数的问题
Jun 05 Python
Python socket如何解析HTTP请求内容
Feb 12 Python
python验证码识别教程之滑动验证码
Jun 04 #Python
python验证码识别教程之利用投影法、连通域法分割图片
Jun 04 #Python
python验证码识别教程之灰度处理、二值化、降噪与tesserocr识别
Jun 04 #Python
实用自动化运维Python脚本分享
Jun 04 #Python
python中验证码连通域分割的方法详解
Jun 04 #Python
python 匹配url中是否存在IP地址的方法
Jun 04 #Python
Python实现ping指定IP的示例
Jun 04 #Python
You might like
用PHP的ob_start() 控制您的浏览器cache
2009/08/03 PHP
PHP 图片上传实现代码 带详细注释
2010/04/29 PHP
PHP查询网站的PR值
2013/10/30 PHP
php实现curl模拟ftp上传的方法
2015/07/29 PHP
PHP常用技巧汇总
2016/03/04 PHP
Symfony2学习笔记之插件格式分析
2016/03/17 PHP
Laravel 实现在Blade模版中使用全局变量代替路径的例子
2019/10/22 PHP
JS的递增/递减运算符和带操作的赋值运算符的等价式
2007/12/08 Javascript
随窗体滑动的小插件sticky源码
2013/06/21 Javascript
JQuery操作iframe父页面与子页面的元素与方法(实例讲解)
2013/11/20 Javascript
js跑步算法的实现代码
2013/12/04 Javascript
在JS数组特定索引处指定位置插入元素
2014/07/27 Javascript
JS实现物体带缓冲的间歇运动效果示例
2016/12/22 Javascript
微信小程序列表中item左滑删除功能
2018/11/07 Javascript
如何使用 vue + d3 画一棵树
2018/12/03 Javascript
小程序扫描普通链接二维码跳转小程序指定界面方法
2019/05/07 Javascript
vue项目中全局引入1个.scss文件的问题解决
2019/08/01 Javascript
js实现GIF图片的分解和合成
2019/10/24 Javascript
如何基于JavaScript判断图片是否加载完成
2019/12/28 Javascript
[15:35]教你分分钟做大人:天怒法师
2014/10/30 DOTA
[59:59]EG vs IG 2018国际邀请赛小组赛BO2 第二场 8.16
2018/08/17 DOTA
pyqt4教程之messagebox使用示例分享
2014/03/07 Python
Django中更新多个对象数据与删除对象的方法
2015/07/17 Python
浅析python的优势和不足之处
2018/11/20 Python
使用python判断你是青少年还是老年人
2018/11/29 Python
pandas 使用均值填充缺失值列的小技巧分享
2019/07/04 Python
python画图——实现在图上标注上具体数值的方法
2019/07/08 Python
python实现简易学生信息管理系统
2020/04/05 Python
python 并发下载器实现方法示例
2019/11/22 Python
TCP/IP中的TCP和IP分别承担什么责任
2012/04/21 面试题
精彩的推荐信范文
2013/11/26 职场文书
银行员工职业规划范文
2014/01/21 职场文书
预备党员承诺书
2014/03/25 职场文书
2015年乡镇民政工作总结
2015/05/13 职场文书
关于vue中如何监听数组变化
2021/04/28 Vue.js
解决ubuntu安装软件时,status-code=409报错的问题
2022/12/24 Servers