Pytorch实现WGAN用于动漫头像生成


Posted in Python onMarch 04, 2021

WGAN与GAN的不同

  • 去除sigmoid
  • 使用具有动量的优化方法,比如使用RMSProp
  • 要对Discriminator的权重做修整限制以确保lipschitz连续约

WGAN实战卷积生成动漫头像 

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os
from anime_face_generator.dataset import ImageDataset
 
batch_size = 32
num_epoch = 100
z_dimension = 100
dir_path = './wgan_img'
 
# 创建文件夹
if not os.path.exists(dir_path):
  os.mkdir(dir_path)
 
 
def to_img(x):
  """因为我们在生成器里面用了tanh"""
  out = 0.5 * (x + 1)
  return out
 
 
dataset = ImageDataset()
dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
 
 
class Generator(nn.Module):
  def __init__(self):
    super().__init__()
 
    self.gen = nn.Sequential(
      # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
      nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
      nn.BatchNorm2d(512),
      nn.ReLU(True),
      # 上一步的输出形状:(512) x 4 x 4
      nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 8 x 8
      nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      # 上一步的输出形状: (256) x 16 x 16
      nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      # 上一步的输出形状:(256) x 32 x 32
      nn.ConvTranspose2d(64, 3, 5, 3, 1, bias=False),
      nn.Tanh() # 输出范围 -1~1 故而采用Tanh
      # nn.Sigmoid()
      # 输出形状:3 x 96 x 96
    )
 
  def forward(self, x):
    x = self.gen(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
class Discriminator(nn.Module):
  def __init__(self):
    super().__init__()
    self.dis = nn.Sequential(
      nn.Conv2d(3, 64, 5, 3, 1, bias=False),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (64) x 32 x 32
 
      nn.Conv2d(64, 128, 4, 2, 1, bias=False),
      nn.BatchNorm2d(128),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (128) x 16 x 16
 
      nn.Conv2d(128, 256, 4, 2, 1, bias=False),
      nn.BatchNorm2d(256),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (256) x 8 x 8
 
      nn.Conv2d(256, 512, 4, 2, 1, bias=False),
      nn.BatchNorm2d(512),
      nn.LeakyReLU(0.2, inplace=True),
      # 输出 (512) x 4 x 4
 
      nn.Conv2d(512, 1, 4, 1, 0, bias=False),
      nn.Flatten(),
      # nn.Sigmoid() # 输出一个数(概率)
    )
 
  def forward(self, x):
    x = self.dis(x)
    return x
 
  def weight_init(m):
    # weight_initialization: important for wgan
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
      m.weight.data.normal_(0, 0.02)
    elif class_name.find('Norm') != -1:
      m.weight.data.normal_(1.0, 0.02)
 
 
def save(model, filename="model.pt", out_dir="out/"):
  if model is not None:
    if not os.path.exists(out_dir):
      os.mkdir(out_dir)
    torch.save({'model': model.state_dict()}, out_dir + filename)
  else:
    print("[ERROR]:Please build a model!!!")
 
 
import QuickModelBuilder as builder
 
if __name__ == '__main__':
  one = torch.FloatTensor([1]).cuda()
  mone = -1 * one
 
  is_print = True
  # 创建对象
  D = Discriminator()
  G = Generator()
  D.weight_init()
  G.weight_init()
 
  if torch.cuda.is_available():
    D = D.cuda()
    G = G.cuda()
 
  lr = 2e-4
  d_optimizer = torch.optim.RMSprop(D.parameters(), lr=lr, )
  g_optimizer = torch.optim.RMSprop(G.parameters(), lr=lr, )
  d_scheduler = torch.optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.99)
  g_scheduler = torch.optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.99)
 
  fake_img = None
 
  # ##########################进入训练##判别器的判断过程#####################
  for epoch in range(num_epoch): # 进行多个epoch的训练
    pbar = builder.MyTqdm(epoch=epoch, maxval=len(dataloader))
    for i, img in enumerate(dataloader):
      num_img = img.size(0)
      real_img = img.cuda() # 将tensor变成Variable放入计算图中
      # 这里的优化器是D的优化器
      for param in D.parameters():
        param.requires_grad = True
      # ########判别器训练train#####################
      # 分为两部分:1、真的图像判别为真;2、假的图像判别为假
 
      # 计算真实图片的损失
      d_optimizer.zero_grad() # 在反向传播之前,先将梯度归0
      real_out = D(real_img) # 将真实图片放入判别器中
      d_loss_real = real_out.mean(0).view(1)
      d_loss_real.backward(one)
 
      # 计算生成图片的损失
      z = torch.randn(num_img, z_dimension).cuda() # 随机生成一些噪声
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。 # 避免梯度传到G,因为G不用更新, detach分离
      fake_out = D(fake_img) # 判别器判断假的图片,
      d_loss_fake = fake_out.mean(0).view(1)
      d_loss_fake.backward(mone)
 
      d_loss = d_loss_fake - d_loss_real
      d_optimizer.step() # 更新参数
 
      # 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c=0.01
      for parm in D.parameters():
        parm.data.clamp_(-0.01, 0.01)
 
      # ==================训练生成器============================
      # ###############################生成网络的训练###############################
      for param in D.parameters():
        param.requires_grad = False
 
      # 这里的优化器是G的优化器,所以不需要冻结D的梯度,因为不是D的优化器,不会更新D
      g_optimizer.zero_grad() # 梯度归0
 
      z = torch.randn(num_img, z_dimension).cuda()
      z = z.reshape(num_img, z_dimension, 1, 1)
      fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
      output = D(fake_img) # 经过判别器得到的结果
      # g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
      g_loss = torch.mean(output).view(1)
      # bp and optimize
      g_loss.backward(one) # 进行反向传播
      g_optimizer.step() # .step()一般用在反向传播后面,用于更新生成网络的参数
 
      # 打印中间的损失
      pbar.set_right_info(d_loss=d_loss.data.item(),
                g_loss=g_loss.data.item(),
                real_scores=real_out.data.mean().item(),
                fake_scores=fake_out.data.mean().item(),
                )
      pbar.update()
      try:
        fake_images = to_img(fake_img.cpu())
        save_image(fake_images, dir_path + '/fake_images-{}.png'.format(epoch + 1))
      except:
        pass
      if is_print:
        is_print = False
        real_images = to_img(real_img.cpu())
        save_image(real_images, dir_path + '/real_images.png')
    pbar.finish()
    d_scheduler.step()
    g_scheduler.step()
    save(D, "wgan_D.pt")
    save(G, "wgan_G.pt")

到此这篇关于Pytorch实现WGAN用于动漫头像生成的文章就介绍到这了,更多相关Pytorch实现WGAN用于动漫头像生成内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
浅谈Python单向链表的实现
Dec 24 Python
编写Python爬虫抓取豆瓣电影TOP100及用户头像的方法
Jan 20 Python
python flask 多对多表查询功能
Jun 25 Python
Python编程给numpy矩阵添加一列方法示例
Dec 04 Python
tensorflow 获取变量&打印权值的实例讲解
Jun 14 Python
Python3爬虫全国地址信息
Jan 05 Python
Tensorflow设置显存自适应,显存比例的操作
Feb 03 Python
Django CSRF认证的几种解决方案
Mar 03 Python
Python读入mnist二进制图像文件并显示实例
Apr 24 Python
Python中内建模块collections如何使用
May 27 Python
详解用Python把PDF转为Word方法总结
Apr 27 Python
这样写python注释让代码更加的优雅
Jun 02 Python
基于PyInstaller各参数的含义说明
Mar 04 #Python
解决Pyinstaller打包软件失败的一个坑
Mar 04 #Python
selenium+python自动化78-autoit参数化与批量上传功能的实现
Mar 04 #Python
解决PDF 转图片时丢文字的一种可能方式
Mar 04 #Python
pandas数据分组groupby()和统计函数agg()的使用
Mar 04 #Python
pyx文件 生成pyd 文件用于 cython调用的实现
Mar 04 #Python
Python .py生成.pyd文件并打包.exe 的注意事项说明
Mar 04 #Python
You might like
PHP 分页原理分析,大家可以看看
2009/12/21 PHP
php构造函数实例讲解
2013/11/13 PHP
ZendFramework框架实现连接两个或多个数据库的方法
2016/12/08 PHP
JavaScript中几种常见排序算法小结
2011/02/22 Javascript
dreamweaver 安装Jquery智能提示
2011/04/02 Javascript
jquery限制输入字数,并提示剩余字数实现代码
2012/12/24 Javascript
jQuery实现用户注册的表单验证示例
2013/08/28 Javascript
纯Javascript实现Windows 8 Metro风格实现
2013/10/15 Javascript
jQuery实现获取table表格第一列值的方法
2016/03/01 Javascript
基于jQuery实现收缩展开功能
2016/03/18 Javascript
使用JS读取XML文件的方法
2016/11/25 Javascript
js实现鼠标跟随运动效果
2020/08/02 Javascript
nodejs调取微信收货地址的方法
2017/12/20 NodeJs
AngularJS $http post 传递参数数据的方法
2018/10/09 Javascript
Vue请求JSON Server服务器数据的实现方法
2018/11/02 Javascript
Vue组件之高德地图地址选择功能的实例代码
2019/06/21 Javascript
JS实现利用闭包判断Dom元素和滚动条的方向示例
2019/08/26 Javascript
分享Angular http interceptors 拦截器使用(推荐)
2019/11/10 Javascript
[02:47]3.19DOTA2发布会 国服成长历程回顾
2014/03/25 DOTA
[03:19]2016国际邀请赛中国区预选赛第四日TOP10镜头集锦
2016/07/01 DOTA
[39:07]LGD vs VP 2018国际邀请赛淘汰赛BO3 第二场 8.21
2018/08/22 DOTA
浅析python 中__name__ = '__main__' 的作用
2014/07/05 Python
Python正则表达式的使用范例详解
2014/08/08 Python
详解python中的Turtle函数库
2018/11/19 Python
使用 Python 在京东上抢口罩的思路详解
2020/02/27 Python
Python实现验证码识别
2020/06/15 Python
Clarria化妆品官方网站:购买天然和有机化妆品系列
2018/04/08 全球购物
英国现代家具和装饰网站:PN Home
2018/08/16 全球购物
亚洲领先的旅游体验市场:Voyagin
2019/11/23 全球购物
企业演讲稿范文大全
2014/05/20 职场文书
中央空调节能方案
2014/06/15 职场文书
单位实习工作证明怎么写
2014/11/02 职场文书
读后感作文评语
2014/12/25 职场文书
小学生暑假安全公约
2015/07/14 职场文书
微软PC Health Check电脑健康状况检查应用下载(Win11配置检测工具)
2021/06/26 数码科技
教你使用Ubuntu搭建DNS服务器
2022/09/23 Servers