Pytorch使用MNIST数据集实现基础GAN和DCGAN详解


Posted in Python onJanuary 10, 2020

原始生成对抗网络Generative Adversarial Networks GAN包含生成器Generator和判别器Discriminator,数据有真实数据groundtruth,还有需要网络生成的“fake”数据,目的是网络生成的fake数据可以“骗过”判别器,让判别器认不出来,就是让判别器分不清进入的数据是真实数据还是fake数据。总的来说是:判别器区分真实数据和fake数据的能力越强越好;生成器生成的数据骗过判别器的能力越强越好,这个是矛盾的,所以只能交替训练网络。

需要搭建生成器网络和判别器网络,训练的时候交替训练。

首先训练判别器的参数,固定生成器的参数,让判别器判断生成器生成的数据,让其和0接近,让判别器判断真实数据,让其和1接近;

接着训练生成器的参数,固定判别器的参数,让生成器生成的数据进入判别器,让判断结果和1接近。生成器生成数据需要给定随机初始值

线性版:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec
 
def showimg(images,count):
 images=images.detach().numpy()[0:16,:]
 images=255*(0.5*images+0.5)
 images = images.astype(np.uint8)
 grid_length=int(np.ceil(np.sqrt(images.shape[0])))
 plt.figure(figsize=(4,4))
 width = int(np.sqrt((images.shape[1])))
 gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
 # gs.update(wspace=0, hspace=0)
 print('starting...')
 for i, img in enumerate(images):
 ax = plt.subplot(gs[i])
 ax.set_xticklabels([])
 ax.set_yticklabels([])
 ax.set_aspect('equal')
 plt.imshow(img.reshape([width,width]),cmap = plt.cm.gray)
 plt.axis('off')
 plt.tight_layout()
 print('showing...')
 plt.tight_layout()
 plt.savefig('./GAN_Image/%d.png'%count, bbox_inches='tight')
 
def loadMNIST(batch_size): #MNIST图片的大小是28*28
 trans_img=transforms.Compose([transforms.ToTensor()])
 trainset=MNIST('./data',train=True,transform=trans_img,download=True)
 testset=MNIST('./data',train=False,transform=trans_img,download=True)
 # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
 testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
 return trainset,testset,trainloader,testloader
 
class discriminator(nn.Module):
 def __init__(self):
 super(discriminator,self).__init__()
 self.dis=nn.Sequential(
  nn.Linear(784,300),
  nn.LeakyReLU(0.2),
  nn.Linear(300,150),
  nn.LeakyReLU(0.2),
  nn.Linear(150,1),
  nn.Sigmoid()
 )
 def forward(self, x):
 x=self.dis(x)
 return x
 
class generator(nn.Module):
 def __init__(self,input_size):
 super(generator,self).__init__()
 self.gen=nn.Sequential(
  nn.Linear(input_size,150),
  nn.ReLU(True),
  nn.Linear(150,300),
  nn.ReLU(True),
  nn.Linear(300,784),
  nn.Tanh()
 )
 def forward(self, x):
 x=self.gen(x)
 return x
 
if __name__=="__main__":
 criterion=nn.BCELoss()
 num_img=100
 z_dimension=100
 D=discriminator()
 G=generator(z_dimension)
 trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
 d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
 g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
 '''
 交替训练的方式训练网络
 先训练判别器网络D再训练生成器网络G
 不同网络的训练次数是超参数
 也可以两个网络训练相同的次数
 这样就可以不用分别训练两个网络
 '''
 count=0
 #鉴别器D的训练,固定G的参数
 epoch = 100
 gepoch = 1
 for i in range(epoch):
 for (img, label) in trainloader:
  # num_img=img.size()[0]
  real_img=img.view(num_img,-1)#展开为28*28=784
  real_label=torch.ones(num_img)#真实label为1
  fake_label=torch.zeros(num_img)#假的label为0
 
  #compute loss of real_img
  real_out=D(real_img) #真实图片送入判别器D输出0~1
  d_loss_real=criterion(real_out,real_label)#得到loss
  real_scores=real_out#真实图片放入判别器输出越接近1越好
 
  #compute loss of fake_img
  z=torch.randn(num_img,z_dimension)#随机生成向量
  fake_img=G(z)#将向量放入生成网络G生成一张图片
  fake_out=D(fake_img)#判别器判断假的图片
  d_loss_fake=criterion(fake_out,fake_label)#假的图片的loss
  fake_scores=fake_out#假的图片放入判别器输出越接近0越好
 
  #D bp and optimize
  d_loss=d_loss_real+d_loss_fake
  d_optimizer.zero_grad() #判别器D的梯度归零
  d_loss.backward() #反向传播
  d_optimizer.step() #更新判别器D参数
 
  #生成器G的训练compute loss of fake_img
  for j in range(gepoch):
  fake_label = torch.ones(num_img) # 真实label为1
  z = torch.randn(num_img, z_dimension) # 随机生成向量
  fake_img = G(z) # 将向量放入生成网络G生成一张图片
  output = D(fake_img) # 经过判别器得到结果
  g_loss = criterion(output, fake_label)#得到假的图片与真实标签的loss
  #bp and optimize
  g_optimizer.zero_grad() #生成器G的梯度归零
  g_loss.backward() #反向传播
  g_optimizer.step()#更新生成器G参数
 print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
   'D real: {:.6f}, D fake: {:.6f}'.format(
  i, epoch, d_loss.data[0], g_loss.data[0],
  real_scores.data.mean(), fake_scores.data.mean()))
 showimg(fake_img,count)
 # plt.show()
 count += 1

这里的图分别是 epoch为0、50、100、150、190的运行结果,可以看到图片中的数字并不单一

Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

卷积版 Deep Convolutional Generative Adversarial Networks:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
 
import matplotlib.gridspec as gridspec
import os
 
def showimg(images,count):
 images=images.to('cpu')
 images=images.detach().numpy()
 images=images[[6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96]]
 images=255*(0.5*images+0.5)
 images = images.astype(np.uint8)
 grid_length=int(np.ceil(np.sqrt(images.shape[0])))
 plt.figure(figsize=(4,4))
 width = images.shape[2]
 gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
 print(images.shape)
 for i, img in enumerate(images):
 ax = plt.subplot(gs[i])
 ax.set_xticklabels([])
 ax.set_yticklabels([])
 ax.set_aspect('equal')
 plt.imshow(img.reshape(width,width),cmap = plt.cm.gray)
 plt.axis('off')
 plt.tight_layout()
# print('showing...')
 plt.tight_layout()
# plt.savefig('./GAN_Imaget/%d.png'%count, bbox_inches='tight')
 
def loadMNIST(batch_size): #MNIST图片的大小是28*28
 trans_img=transforms.Compose([transforms.ToTensor()])
 trainset=MNIST('./data',train=True,transform=trans_img,download=True)
 testset=MNIST('./data',train=False,transform=trans_img,download=True)
 # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
 testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
 return trainset,testset,trainloader,testloader
 
class discriminator(nn.Module):
 def __init__(self):
 super(discriminator,self).__init__()
 self.dis=nn.Sequential(
  nn.Conv2d(1,32,5,stride=1,padding=2),
  nn.LeakyReLU(0.2,True),
  nn.MaxPool2d((2,2)),
 
  nn.Conv2d(32,64,5,stride=1,padding=2),
  nn.LeakyReLU(0.2,True),
  nn.MaxPool2d((2,2))
 )
 self.fc=nn.Sequential(
  nn.Linear(7 * 7 * 64, 1024),
  nn.LeakyReLU(0.2, True),
  nn.Linear(1024, 1),
  nn.Sigmoid()
 )
 def forward(self, x):
 x=self.dis(x)
 x=x.view(x.size(0),-1)
 x=self.fc(x)
 return x
 
class generator(nn.Module):
 def __init__(self,input_size,num_feature):
 super(generator,self).__init__()
 self.fc=nn.Linear(input_size,num_feature) #1*56*56
 self.br=nn.Sequential(
  nn.BatchNorm2d(1),
  nn.ReLU(True)
 )
 self.gen=nn.Sequential(
  nn.Conv2d(1,50,3,stride=1,padding=1),
  nn.BatchNorm2d(50),
  nn.ReLU(True),
 
  nn.Conv2d(50,25,3,stride=1,padding=1),
  nn.BatchNorm2d(25),
  nn.ReLU(True),
 
  nn.Conv2d(25,1,2,stride=2),
  nn.Tanh()
 )
 def forward(self, x):
 x=self.fc(x)
 x=x.view(x.size(0),1,56,56)
 x=self.br(x)
 x=self.gen(x)
 return x
 
if __name__=="__main__":
 criterion=nn.BCELoss()
 num_img=100
 z_dimension=100
 D=discriminator()
 G=generator(z_dimension,3136) #1*56*56
 trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
 D=D.cuda()
 G=G.cuda()
 d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
 g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
 '''
 交替训练的方式训练网络
 先训练判别器网络D再训练生成器网络G
 不同网络的训练次数是超参数
 也可以两个网络训练相同的次数,
 这样就可以不用分别训练两个网络
 '''
 count=0
 #鉴别器D的训练,固定G的参数
 epoch = 100
 gepoch = 1
 for i in range(epoch):
 for (img, label) in trainloader:
  # num_img=img.size()[0]
  img=Variable(img).cuda()
  real_label=Variable(torch.ones(num_img)).cuda()#真实label为1
  fake_label=Variable(torch.zeros(num_img)).cuda()#假的label为0
 
  #compute loss of real_img
  real_out=D(img) #真实图片送入判别器D输出0~1
  d_loss_real=criterion(real_out,real_label)#得到loss
  real_scores=real_out#真实图片放入判别器输出越接近1越好
 
  #compute loss of fake_img
  z=Variable(torch.randn(num_img,z_dimension)).cuda()#随机生成向量
  fake_img=G(z)#将向量放入生成网络G生成一张图片
  fake_out=D(fake_img)#判别器判断假的图片
  d_loss_fake=criterion(fake_out,fake_label)#假的图片的loss
  fake_scores=fake_out#假的图片放入判别器输出越接近0越好
 
  #D bp and optimize
  d_loss=d_loss_real+d_loss_fake
  d_optimizer.zero_grad() #判别器D的梯度归零
  d_loss.backward() #反向传播
  d_optimizer.step() #更新判别器D参数
 
  #生成器G的训练compute loss of fake_img
  for j in range(gepoch):
  fake_label = Variable(torch.ones(num_img)).cuda() # 真实label为1
  z = Variable(torch.randn(num_img, z_dimension)).cuda() # 随机生成向量
  fake_img = G(z) # 将向量放入生成网络G生成一张图片
  output = D(fake_img) # 经过判别器得到结果
  g_loss = criterion(output, fake_label)#得到假的图片与真实标签的loss
  #bp and optimize
  g_optimizer.zero_grad() #生成器G的梯度归零
  g_loss.backward() #反向传播
  g_optimizer.step()#更新生成器G参数
  # if ((i+1)%1000==0):
  # print("[%d/%d] GLoss: %.5f" % (i + 1, gepoch, g_loss.data[0]))
 print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
   'D real: {:.6f}, D fake: {:.6f}'.format(
  i, epoch, d_loss.data[0], g_loss.data[0],
  real_scores.data.mean(), fake_scores.data.mean()))
 showimg(fake_img,count)
 plt.show()
 count += 1

这里的gepoch设置为1,运行39次的结果是:

Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

gepoch设置为2,运行0、25、50、75、100次的结果是:

Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

gepoch设置为3,运行25、50、75次的结果是:

Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

gepoch设置为4,运行0、10、20、30、35次的结果是:

Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

gepoch设置为5,运行0、10、20、25、29次的结果是:

Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

gepoch设置为3,z_dimension设置为190,epoch运行0、10、15、20、25、35的结果是:

Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

可以看到生成的数字基本没有太多的规律,可能最终都是同个数字,不能生成指定的数字,CGAN就很好的解决这个问题,可以生成指定的数字 Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式

以上这篇Pytorch使用MNIST数据集实现基础GAN和DCGAN详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现多行注释的另类方法
Aug 22 Python
python中numpy.zeros(np.zeros)的使用方法
Nov 07 Python
Python编程实现蚁群算法详解
Nov 13 Python
python2 与 python3 实现共存的方法
Jul 12 Python
Python开启线程,在函数中开线程的实例
Feb 22 Python
浅析Python与Mongodb数据库之间的操作方法
Jul 01 Python
浅谈在django中使用filter()(即对QuerySet操作)时踩的坑
Mar 31 Python
解决c++调用python中文乱码问题
Jul 29 Python
python语音识别指南终极版(有这一篇足矣)
Sep 09 Python
Python操作Excel的学习笔记
Feb 18 Python
python 基于pygame实现俄罗斯方块
Mar 02 Python
Opencv实现二维直方图的计算及绘制
Jul 21 Python
Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式
Jan 10 #Python
pytorch实现mnist分类的示例讲解
Jan 10 #Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
You might like
杏林同学录(三)
2006/10/09 PHP
ThinkPHP3.1基础知识快速入门
2014/06/19 PHP
改写ThinkPHP的U方法使其路由下分页正常
2014/07/02 PHP
php类自动装载、链式操作、魔术方法实现代码
2017/07/23 PHP
基于ThinkPHP5.0实现图片上传插件
2017/09/25 PHP
window.parent调用父框架时 ie跟火狐不兼容问题
2009/07/30 Javascript
json 实例详细说明教程
2009/10/31 Javascript
javascript:history.go()和History.back()的区别及应用
2012/11/25 Javascript
jQuery不间断滚动效果(模拟百度新闻支持文字/图片/垂直滚动)
2013/02/05 Javascript
JavaScript实现的日期控件具体代码
2013/11/18 Javascript
JavaScript获取图片的原始尺寸以宽度为例
2014/05/04 Javascript
JavaScript加入收藏夹功能(兼容IE、firefox、chrome)
2014/05/05 Javascript
node.js中的fs.readdirSync方法使用说明
2014/12/17 Javascript
jQuery实用技巧必备(下)
2015/11/03 Javascript
解决jquery插件:TypeError:$.browser is undefined报错的方法
2015/11/21 Javascript
jQuery实现的自定义滚动条实例详解
2016/09/20 Javascript
javascript循环链表之约瑟夫环的实现方法
2017/01/16 Javascript
微信小程序实现类似微信点击语音播放效果
2020/03/30 Javascript
深入浅析ng-bootstrap 组件集中 tabset 组件的实现分析
2019/07/19 Javascript
JavaScript实现串行请求的示例代码
2020/09/14 Javascript
[52:36]VGJ.S vs Serenity 2018国际邀请赛小组赛BO2 第一场 8.19
2018/08/21 DOTA
用Python实现web端用户登录和注册功能的教程
2015/04/30 Python
TensorFlow深度学习之卷积神经网络CNN
2018/03/09 Python
Python3用tkinter和PIL实现看图工具
2018/06/21 Python
Python爬虫爬取新浪微博内容示例【基于代理IP】
2018/08/03 Python
python绘制多个曲线的折线图
2020/03/23 Python
python做接口测试的必要性
2019/11/20 Python
解决pycharm安装第三方库失败的问题
2020/05/09 Python
django实现日志按日期分割
2020/05/21 Python
CSS3 animation ? steps 函数详解
2019/08/30 HTML / CSS
全球采购的街头服饰和帽子:Urban Excess
2020/10/28 全球购物
优秀的应届生自荐信
2014/05/23 职场文书
工作试用期自我评价
2015/03/10 职场文书
Mysql数据库索引面试题(程序员基础技能)
2021/05/31 MySQL
Python中异常处理用法
2021/11/27 Python
Java实现扫雷游戏详细代码讲解
2022/05/25 Java/Android