pytorch:实现简单的GAN示例(MNIST数据集)


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python最长公共子串算法实例
Mar 07 Python
Django中对通过测试的用户进行限制访问的方法
Jul 23 Python
Python3学习笔记之列表方法示例详解
Oct 06 Python
对Python中range()函数和list的比较
Apr 19 Python
Python实现动态添加属性和方法操作示例
Jul 25 Python
Python读取mat文件,并保存为pickle格式的方法
Oct 23 Python
python使用递归的方式建立二叉树
Jul 03 Python
Python 利用邮件系统完成远程控制电脑的实现(关机、重启等)
Nov 19 Python
wxPython实现画图板
Aug 27 Python
pandas 强制类型转换 df.astype实例
Apr 09 Python
Python如何获取文件路径/目录
Sep 22 Python
Python截图并保存的具体实例
Jan 14 Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
You might like
php动态添加url查询参数的方法
2015/04/14 PHP
Laravel使用Caching缓存数据减轻数据库查询压力的方法
2016/03/15 PHP
Fleaphp常见函数功能与用法示例
2016/11/15 PHP
php抽象方法和普通方法的区别点总结
2019/10/13 PHP
Javascript笔记一 js以及json基础使用说明
2010/05/22 Javascript
IE6/7 and IE8/9/10(IE7模式)依次隐藏具有absolute或relative的父元素和子元素后再显示父元素
2011/07/31 Javascript
Nodejs中自定义事件实例
2014/06/20 NodeJs
限制只能输入数字的实现代码
2016/05/16 Javascript
Bootstrap Modal对话框如何在关闭时触发事件
2016/12/02 Javascript
学习使用jQuery表单验证插件和日历插件
2017/02/13 Javascript
vue mintui-Loadmore结合实现下拉刷新和上拉加载示例
2017/10/12 Javascript
浅谈Angular 中何时取消订阅
2017/11/22 Javascript
vue脚手架搭建项目的兼容性配置详解
2018/07/17 Javascript
nodejs之koa2请求示例(GET,POST)
2018/08/07 NodeJs
VUE+Element UI实现简单的表格行内编辑效果的示例的代码
2018/10/31 Javascript
对Layer弹窗使用及返回数据接收的实例详解
2019/09/26 Javascript
如何在node环境实现“get数据解析”代码实例
2020/07/03 Javascript
微信小程序中data-key属性之数据传输(经验总结)
2020/08/22 Javascript
用Javascript实现发送短信验证码间隔功能
2021/02/08 Javascript
Python批量更改文件名的实现方法
2017/10/29 Python
python验证码识别实例代码
2018/02/03 Python
python实现可视化动态CPU性能监控
2018/06/21 Python
对django中render()与render_to_response()的区别详解
2018/10/16 Python
Python 实现取矩阵的部分列,保存为一个新的矩阵方法
2018/11/14 Python
python自定义时钟类、定时任务类
2021/02/22 Python
python合并多个excel文件的示例
2020/09/23 Python
python基于exchange函数发送邮件过程详解
2020/11/06 Python
美国眼镜网站:LensCrafters
2020/01/19 全球购物
Traffic People官网:女式花裙、上衣和连身裤
2020/10/12 全球购物
奖学金自我鉴定范文
2013/10/03 职场文书
家长给孩子的表扬信
2014/01/17 职场文书
银行党员批评与自我批评
2014/10/15 职场文书
2014学生会工作总结报告
2014/12/02 职场文书
2014年英语工作总结
2014/12/20 职场文书
高中家长意见怎么写
2015/06/03 职场文书
SONY AN-LP1 短波有源天线放大器图
2022/04/05 无线电