pytorch GAN伪造手写体mnist数据集方式


Posted in Python onJanuary 10, 2020

一,mnist数据集

pytorch GAN伪造手写体mnist数据集方式

形如上图的数字手写体就是mnist数据集。

二,GAN原理(生成对抗网络)

GAN网络一共由两部分组成:一个是伪造器(Generator,简称G),一个是判别器(Discrimniator,简称D)

一开始,G由服从某几个分布(如高斯分布)的噪音组成,生成的图片不断送给D判断是否正确,直到G生成的图片连D都判断以为是真的。D每一轮除了看过G生成的假图片以外,还要见数据集中的真图片,以前者和后者得到的损失函数值为依据更新D网络中的权值。因此G和D都在不停地更新权值。以下图为例:

pytorch GAN伪造手写体mnist数据集方式

在v1时的G只不过是 一堆噪声,见过数据集(real images)的D肯定能判断出G所生成的是假的。当然G也能知道D判断它是假的这个结果,因此G就会更新权值,到v2的时候,G就能生成更逼真的图片来让D判断,当然在v2时D也是会先看一次真图片,再去判断G所生成的图片。以此类推,不断循环就是GAN的思想。

三,训练代码

import argparse
import os
import numpy as np
import math
 
import torchvision.transforms as transforms
from torchvision.utils import save_image
 
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
 
import torch.nn as nn
import torch.nn.functional as F
import torch
 
os.makedirs("images", exist_ok=True)
 
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)
 
img_shape = (opt.channels, opt.img_size, opt.img_size) # 确定图片输入的格式为(1,28,28),由于mnist数据集是灰度图所以通道为1
cuda = True if torch.cuda.is_available() else False
 
 
class Generator(nn.Module):
 def __init__(self):
  super(Generator, self).__init__()
 
  def block(in_feat, out_feat, normalize=True):
   layers = [nn.Linear(in_feat, out_feat)]
   if normalize:
    layers.append(nn.BatchNorm1d(out_feat, 0.8))
   layers.append(nn.LeakyReLU(0.2, inplace=True))
   return layers
 
  self.model = nn.Sequential(
   *block(opt.latent_dim, 128, normalize=False),
   *block(128, 256),
   *block(256, 512),
   *block(512, 1024),
   nn.Linear(1024, int(np.prod(img_shape))),
   nn.Tanh()
  )
 
 def forward(self, z):
  img = self.model(z)
  img = img.view(img.size(0), *img_shape)
  return img
 
 
class Discriminator(nn.Module):
 def __init__(self):
  super(Discriminator, self).__init__()
 
  self.model = nn.Sequential(
   nn.Linear(int(np.prod(img_shape)), 512),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(512, 256),
   nn.LeakyReLU(0.2, inplace=True),
   nn.Linear(256, 1),
   nn.Sigmoid(),
  )
 
 def forward(self, img):
  img_flat = img.view(img.size(0), -1)
  validity = self.model(img_flat)
  return validity
 
 
# Loss function
adversarial_loss = torch.nn.BCELoss()
 
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
 
if cuda:
 generator.cuda()
 discriminator.cuda()
 adversarial_loss.cuda()
 
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
 datasets.MNIST(
  "../../data/mnist",
  train=True,
  download=True,
  transform=transforms.Compose(
   [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
  ),
 ),
 batch_size=opt.batch_size,
 shuffle=True,
)
 
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
 
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
 
# ----------
# Training
# ----------
if __name__ == '__main__':
 for epoch in range(opt.n_epochs):
  for i, (imgs, _) in enumerate(dataloader):
   # print(imgs.shape)
   # Adversarial ground truths
   valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 全1
   fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 全0
   # Configure input
   real_imgs = Variable(imgs.type(Tensor))
 
   # -----------------
   # Train Generator
   # -----------------
 
   optimizer_G.zero_grad() # 清空G网络 上一个batch的梯度
 
   # Sample noise as generator input
   z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 生成的噪音,均值为0方差为1维度为(64,100)的噪音
   # Generate a batch of images
   gen_imgs = generator(z)
   # Loss measures generator's ability to fool the discriminator
   g_loss = adversarial_loss(discriminator(gen_imgs), valid)
 
   g_loss.backward() # g_loss用于更新G网络的权值,g_loss于D网络的判断结果 有关
   optimizer_G.step()
 
   # ---------------------
   # Train Discriminator
   # ---------------------
 
   optimizer_D.zero_grad() # 清空D网络 上一个batch的梯度
   # Measure discriminator's ability to classify real from generated samples
   real_loss = adversarial_loss(discriminator(real_imgs), valid)
   fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
   d_loss = (real_loss + fake_loss) / 2
 
   d_loss.backward() # d_loss用于更新D网络的权值
   optimizer_D.step()
 
   print(
    "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
    % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
   )
 
   batches_done = epoch * len(dataloader) + i
   if batches_done % opt.sample_interval == 0:
    save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) # 保存一个batchsize中的25张
   if (epoch+1) %2 ==0:
    print('save..')
    torch.save(generator,'g%d.pth' % epoch)
    torch.save(discriminator,'d%d.pth' % epoch)

运行结果:

一开始时,G生成的全是杂音:

pytorch GAN伪造手写体mnist数据集方式

然后逐渐呈现数字的雏形:

pytorch GAN伪造手写体mnist数据集方式

最后一次生成的结果:

pytorch GAN伪造手写体mnist数据集方式

四,测试代码:

导入最后保存生成器的模型:

from gan import Generator,Discriminator
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import numpy as np
from torchvision.utils import save_image
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Tensor = torch.cuda.FloatTensor
g = torch.load('g199.pth') #导入生成器Generator模型
#d = torch.load('d.pth')
g = g.to(device)
#d = d.to(device)
 
z = Variable(Tensor(np.random.normal(0, 1, (64, 100)))) #输入的噪音
gen_imgs =g(z) #生产图片
save_image(gen_imgs.data[:25], "images.png" , nrow=5, normalize=True)

生成结果:

pytorch GAN伪造手写体mnist数据集方式

以上这篇pytorch GAN伪造手写体mnist数据集方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中使用 Selenium 实现网页截图实例
Jul 18 Python
Python常见数据结构详解
Jul 24 Python
python if not in 多条件判断代码
Sep 21 Python
python爬虫爬取淘宝商品信息(selenum+phontomjs)
Feb 24 Python
python实现狄克斯特拉算法
Jan 17 Python
Python实现的IP端口扫描工具类示例
Feb 15 Python
Pandas DataFrame数据的更改、插入新增的列和行的方法
Jun 25 Python
Python Numpy计算各类距离的方法
Jul 05 Python
Apache部署Django项目图文详解
Jul 30 Python
python 实现在shell窗口中编写print不向屏幕输出
Feb 19 Python
简单了解Python字典copy与赋值的区别
Sep 16 Python
python 实现"神经衰弱"翻牌游戏
Nov 09 Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
You might like
PHP操作文件类的函数代码(文件和文件夹创建,复制,移动和删除)
2011/11/10 PHP
[原创]php简单隔行变色功能实现代码
2016/07/09 PHP
PHP读取、解析eml文件及生成网页的方法示例
2017/09/04 PHP
激活 ActiveX 控件
2006/10/09 Javascript
JQuery 常用操作代码
2010/03/14 Javascript
javascript判断ie浏览器6/7版本加载不同样式表的实现代码
2011/12/26 Javascript
用RadioButten或CheckBox实现div的显示与隐藏
2013/09/21 Javascript
用JS在浏览器中创建下载文件
2014/03/05 Javascript
使用jQuery实现星级评分代码分享
2014/12/09 Javascript
JavaScript中的setUTCDate()方法使用详解
2015/06/11 Javascript
使用控制台破解百小度一个月只准改一次名字
2015/08/13 Javascript
jQuery数据类型小结(14个)
2016/01/08 Javascript
js基于cookie记录来宾姓名的方法
2016/07/19 Javascript
基于HTML+CSS+JS实现增加删除修改tab导航特效代码
2016/08/05 Javascript
bootstrap中的 form表单属性role="form"的作用详解
2017/01/20 Javascript
纯js三维数组实现三级联动效果
2017/02/07 Javascript
利用Jasmine对Angular进行单元测试的方法详解
2017/06/12 Javascript
使用MUI框架模拟手机端的下拉刷新和上拉加载功能
2017/09/04 Javascript
JavaScript实现的级联算法示例【省市二级联动功能】
2018/12/25 Javascript
微信小程序 wx:for遍历循环使用实例解析
2019/09/09 Javascript
NodeJs crypto加密制作token的实现代码
2019/11/15 NodeJs
原生js拖拽功能制作滑动条实例代码
2021/02/05 Javascript
python检测lvs real server状态
2014/01/22 Python
python实现的udp协议Server和Client代码实例
2014/06/04 Python
Python3爬虫学习之将爬取的信息保存到本地的方法详解
2018/12/12 Python
python 实现分页显示从es中获取的数据方法
2018/12/26 Python
python石头剪刀布小游戏(三局两胜制)
2021/01/20 Python
Python如何使用字符打印照片
2020/01/03 Python
matplotlib 画双轴子图无法显示x轴的解决方法
2020/07/27 Python
python实现批处理文件
2020/07/28 Python
文职个人求职信范文
2013/09/23 职场文书
学校七一活动方案
2014/01/19 职场文书
《悯农》教学反思
2014/04/28 职场文书
诉前财产保全担保书
2014/05/20 职场文书
学习教师敬业奉献模范事迹材料思想汇报
2014/09/19 职场文书
Mysql Innodb存储引擎之索引与算法
2022/02/15 MySQL