pytorch:实现简单的GAN示例(MNIST数据集)


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python常规方法实现数组的全排列
Mar 17 Python
python实现决策树分类算法
Dec 21 Python
详解python 拆包可迭代数据如tuple, list
Dec 29 Python
Python实现的朴素贝叶斯分类器示例
Jan 06 Python
Python+matplotlib实现计算两个信号的交叉谱密度实例
Jan 08 Python
django允许外部访问的实例讲解
May 14 Python
3分钟学会一个Python小技巧
Nov 23 Python
Python中按键来获取指定的值
Mar 02 Python
pygame实现成语填空游戏
Oct 29 Python
pandas.DataFrame.drop_duplicates 用法介绍
Jul 06 Python
python之随机数函数的实现示例
Dec 30 Python
pycharm配置python 设置pip安装源为豆瓣源
Feb 05 Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
You might like
第二节--PHP5 的对象模型
2006/11/16 PHP
php数组函数序列之array_flip() 将数组键名与值对调
2011/11/07 PHP
PHP中设置时区,记录日志文件的实现代码
2013/01/07 PHP
PHP的mysqli_query参数MYSQLI_STORE_RESULT和MYSQLI_USE_RESULT的区别
2014/09/29 PHP
Javascript 面向对象编程(coolshell)
2012/03/18 Javascript
转换字符串为json对象的方法详解
2013/11/29 Javascript
BootStrap 智能表单实战系列(二)BootStrap支持的类型简介
2016/06/13 Javascript
js实现文字超出部分用省略号代替实例代码
2016/09/01 Javascript
Vue.js动态组件解析
2016/09/09 Javascript
使用vue.js实现联动效果的示例代码
2017/01/10 Javascript
Javascript中this关键字指向问题的测试与详解
2017/08/11 Javascript
angularJs 表格添加删除修改查询方法
2018/02/27 Javascript
微信小程序云开发之使用云存储
2019/05/17 Javascript
适合前端Vue开发童鞋的跨平台Weex的使用详解
2019/10/16 Javascript
微信小程序上传图片并等比列压缩到指定大小的实例代码
2019/10/24 Javascript
Vue实例的对象参数options的几个常用选项详解
2019/11/08 Javascript
vue仿淘宝滑动验证码功能(样式模仿)
2019/12/10 Javascript
JS三级联动代码格式实例详解
2019/12/30 Javascript
详解Node.JS模块 process
2020/08/31 Javascript
小议Python中自定义函数的可变参数的使用及注意点
2016/06/21 Python
浅谈Python爬取网页的编码处理
2016/11/04 Python
详解Python自建logging模块
2018/01/29 Python
python+opencv识别图片中的圆形
2020/03/25 Python
SELENIUM自动化模拟键盘快捷键操作实现解析
2019/10/28 Python
使用Tensorflow实现可视化中间层和卷积层
2020/01/24 Python
Python random库使用方法及异常处理方案
2020/03/02 Python
使用css3 属性如何丰富图片样式(圆角 阴影 渐变)
2012/11/22 HTML / CSS
墨尔本照明批发商店:Mica Lighting
2017/12/28 全球购物
东方通信股份有限公司VC面试题
2014/08/27 面试题
应聘医药代表职位求职信
2013/10/21 职场文书
大学生大二自我鉴定
2013/10/28 职场文书
验房委托书
2014/08/30 职场文书
资料员岗位职责范本
2015/04/13 职场文书
运动会表扬稿范文
2015/05/05 职场文书
详解PHP用mb_string处理windows中文字符
2021/05/26 PHP
netty 实现tomcat的示例代码
2022/06/05 Servers