pytorch GAN生成对抗网络实例


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)
np.random.seed(1)

BATCH_SIZE = 64
LR_G = 0.0001
LR_D = 0.0001
N_IDEAS = 5
ART_COMPONENTS = 15
PAINT_POINTS = np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])

def artist_works():
	a = np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
	paintings = a*np.power(PAINT_POINTS,2) + (a-1)
	paintings = torch.from_numpy(paintings).float()
	return Variable(paintings)

G = nn.Sequential(
	nn.Linear(N_IDEAS,128),
	nn.ReLU(),
	nn.Linear(128,ART_COMPONENTS),
)

D = nn.Sequential(
	nn.Linear(ART_COMPONENTS,128),
	nn.ReLU(),
	nn.Linear(128,1),
	nn.Sigmoid(),
)

opt_D = torch.optim.Adam(D.parameters(),lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(),lr=LR_G)

plt.ion()

for step in range(10000):
	artist_paintings = artist_works()
	G_ideas = Variable(torch.randn(BATCH_SIZE,N_IDEAS))
	G_paintings = G(G_ideas)

	prob_artist0 = D(artist_paintings)
	prob_artist1 = D(G_paintings)

	D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1-prob_artist1))
	G_loss = torch.mean(torch.log(1 - prob_artist1))

	opt_D.zero_grad()
	D_loss.backward(retain_variables=True)
	opt_D.step()

	opt_G.zero_grad()
	G_loss.backward()
	opt_G.step()

	if step % 50 == 0:
		plt.cla()
		plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0],c='#4ad631',lw=3,label='Generated painting',)
		plt.plot(PAINT_POINTS[0],2 * np.power(PAINT_POINTS[0], 2) + 1,c='#74BCFF',lw=3,label='upper bound',)
		plt.plot(PAINT_POINTS[0],1 * np.power(PAINT_POINTS[0], 2) + 0,c='#FF9359',lw=3,label='lower bound',)
		plt.text(-.5,2.3,'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size':15})
		plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 15})
		plt.ylim((0,3))
		plt.legend(loc='upper right', fontsize=12)
		plt.draw()
		plt.pause(0.01)

plt.ioff()
plt.show()

pytorch GAN生成对抗网络实例

以上这篇pytorch GAN生成对抗网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Django1.3添加app提示模块不存在的解决方法
Aug 26 Python
简述Python中的面向对象编程的概念
Apr 27 Python
python中实现指定时间调用函数示例代码
Sep 08 Python
Django 忘记管理员或忘记管理员密码 重设登录密码的方法
May 30 Python
python3实现多线程聊天室
Dec 12 Python
python截取两个单词之间的内容方法
Dec 25 Python
pandas factorize实现将字符串特征转化为数字特征
Dec 19 Python
Python基于Dlib的人脸识别系统的实现
Feb 26 Python
Python使用Numpy模块读取文件并绘制图片
May 13 Python
Python基于xlrd模块处理合并单元格
Jul 28 Python
如何使用flask将模型部署为服务
May 13 Python
Python编程中Python与GIL互斥锁关系作用分析
Sep 15 Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
python-OpenCV 实现将数组转换成灰度图和彩图
Jan 09 #Python
You might like
JAVA/JSP学习系列之七
2006/10/09 PHP
golang 调用 php7详解及实例
2017/01/04 PHP
替换php字符串中的单引号为双引号的方法
2017/02/16 PHP
深入理解 PHP7 中全新的 zval 容器和引用计数机制
2018/10/15 PHP
关于Laravel参数验证的一些疑与惑
2019/11/19 PHP
PHP http请求超时问题解决方案
2020/11/13 PHP
JavaScrip实现PHP print_r的数功能(三种方法)
2013/11/12 Javascript
JS+DIV实现鼠标划过切换层效果的实例代码
2013/11/26 Javascript
快速掌握Node.js模块封装及使用
2016/03/21 Javascript
js实现登录框鼠标拖拽效果
2017/03/09 Javascript
使用travis-ci如何持续部署node.js应用详解
2017/07/30 Javascript
解决IE7中使用jQuery动态操作name问题
2017/08/28 jQuery
从parcel.js打包出错到选择nvm的全部过程
2018/01/23 Javascript
JavaScript继承与多继承实例分析
2018/05/26 Javascript
vue-cli2打包前和打包后的css前缀不一致的问题解决
2018/08/24 Javascript
nodejs npm错误Error:UNKNOWN:unknown error,mkdir 'D:\Develop\nodejs\node_global'at Error
2019/03/02 NodeJs
Nodejs异步流程框架async的方法
2019/06/07 NodeJs
axios实现文件上传并获取进度
2020/03/25 Javascript
[01:07:41]IG vs VGJ.T 2018国际邀请赛小组赛BO2 第一场 8.18
2018/08/19 DOTA
详解python开发环境搭建
2016/12/16 Python
解决python使用open打开文件中文乱码的问题
2017/12/29 Python
python实现控制电脑鼠标和键盘,登录QQ的方法示例
2019/07/06 Python
Pycharm小白级简单使用教程
2020/01/08 Python
在spyder IPython console中,运行代码加入参数的实例
2020/04/20 Python
python 基于卡方值分箱算法的实现示例
2020/07/17 Python
Python如何绘制日历图和热力图
2020/08/07 Python
html5实现输入框fixed定位在屏幕最底部兼容性
2020/07/03 HTML / CSS
"引用"与多态的关系
2013/02/01 面试题
应聘教师推荐信
2013/10/31 职场文书
绩效专员岗位职责
2013/12/02 职场文书
高中综合实践活动总结
2014/07/07 职场文书
感恩老师演讲稿600字
2014/08/28 职场文书
2014年教务处工作总结
2014/12/03 职场文书
整改通知书
2015/04/20 职场文书
oracle覆盖导入dmp文件的2种方法
2021/05/21 Oracle
python如何利用traceback获取详细的异常信息
2021/06/05 Python