pytorch GAN生成对抗网络实例


Posted in Python onJanuary 10, 2020

我就废话不多说了,直接上代码吧!

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)
np.random.seed(1)

BATCH_SIZE = 64
LR_G = 0.0001
LR_D = 0.0001
N_IDEAS = 5
ART_COMPONENTS = 15
PAINT_POINTS = np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])

def artist_works():
	a = np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
	paintings = a*np.power(PAINT_POINTS,2) + (a-1)
	paintings = torch.from_numpy(paintings).float()
	return Variable(paintings)

G = nn.Sequential(
	nn.Linear(N_IDEAS,128),
	nn.ReLU(),
	nn.Linear(128,ART_COMPONENTS),
)

D = nn.Sequential(
	nn.Linear(ART_COMPONENTS,128),
	nn.ReLU(),
	nn.Linear(128,1),
	nn.Sigmoid(),
)

opt_D = torch.optim.Adam(D.parameters(),lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(),lr=LR_G)

plt.ion()

for step in range(10000):
	artist_paintings = artist_works()
	G_ideas = Variable(torch.randn(BATCH_SIZE,N_IDEAS))
	G_paintings = G(G_ideas)

	prob_artist0 = D(artist_paintings)
	prob_artist1 = D(G_paintings)

	D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1-prob_artist1))
	G_loss = torch.mean(torch.log(1 - prob_artist1))

	opt_D.zero_grad()
	D_loss.backward(retain_variables=True)
	opt_D.step()

	opt_G.zero_grad()
	G_loss.backward()
	opt_G.step()

	if step % 50 == 0:
		plt.cla()
		plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0],c='#4ad631',lw=3,label='Generated painting',)
		plt.plot(PAINT_POINTS[0],2 * np.power(PAINT_POINTS[0], 2) + 1,c='#74BCFF',lw=3,label='upper bound',)
		plt.plot(PAINT_POINTS[0],1 * np.power(PAINT_POINTS[0], 2) + 0,c='#FF9359',lw=3,label='lower bound',)
		plt.text(-.5,2.3,'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size':15})
		plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 15})
		plt.ylim((0,3))
		plt.legend(loc='upper right', fontsize=12)
		plt.draw()
		plt.pause(0.01)

plt.ioff()
plt.show()

pytorch GAN生成对抗网络实例

以上这篇pytorch GAN生成对抗网络实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
剖析Python的Tornado框架中session支持的实现代码
Aug 21 Python
Python编程实现正则删除命令功能
Aug 30 Python
Python数据分析之如何利用pandas查询数据示例代码
Sep 01 Python
Odoo中如何生成唯一不重复的序列号详解
Feb 10 Python
Win10+GPU版Pytorch1.1安装的安装步骤
Sep 27 Python
python中for循环变量作用域及用法详解
Nov 05 Python
Python上下文管理器用法及实例解析
Nov 11 Python
python和pywin32实现窗口查找、遍历和点击的示例代码
Apr 01 Python
pytorch cuda上tensor的定义 以及减少cpu的操作详解
Jun 23 Python
tensorflow 大于某个值为1,小于为0的实例
Jun 30 Python
python UDF 实现对csv批量md5加密操作
Jan 01 Python
Python网络编程之ZeroMQ知识总结
Apr 25 Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
python列表推导和生成器表达式知识点总结
Jan 10 #Python
pytorch的梯度计算以及backward方法详解
Jan 10 #Python
Python如何获取Win7,Win10系统缩放大小
Jan 10 #Python
python-OpenCV 实现将数组转换成灰度图和彩图
Jan 09 #Python
You might like
PHP 动态随机生成验证码类代码
2010/04/09 PHP
Linux Apache PHP Oracle 安装配置(具体操作步骤)
2013/06/17 PHP
深入array multisort排序原理的详解
2013/06/18 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(十六)
2014/06/30 PHP
php提交post数组参数实例分析
2015/12/17 PHP
PHP一致性hash分布式算法封装类定义与用法示例
2018/08/04 PHP
js select常用操作控制代码
2010/03/16 Javascript
AJAX 网页保留浏览器前进后退等功能
2011/02/12 Javascript
jquery取消选择select下拉框示例代码
2014/02/22 Javascript
jquery简单实现幻灯片的方法
2015/08/03 Javascript
基于javascript实现漂亮的页面过渡动画效果附源码下载
2015/10/26 Javascript
javascript精确统计网站访问量实例代码
2015/12/19 Javascript
微信小程序 支付简单实例及注意事项
2017/01/06 Javascript
详解vue-router2.0动态路由获取参数
2017/06/14 Javascript
简单谈谈JS中的正则表达式
2017/09/11 Javascript
jQuery获取所有父级元素及同级元素及子元素的方法(推荐)
2018/01/21 jQuery
JavaScript运行原理分析
2018/02/09 Javascript
详解Vue.js中引入图片路径的几种方式
2019/06/17 Javascript
vue2.0+SVG实现音乐播放圆形进度条组件
2019/09/21 Javascript
[07:37]DOTA2-DPC中国联赛2月2日Recap集锦
2021/03/11 DOTA
Python自定义函数的创建、调用和函数的参数详解
2014/03/11 Python
python使用append合并两个数组的方法
2015/04/28 Python
对python中的try、except、finally 执行顺序详解
2019/02/18 Python
python3.x+pyqt5实现主窗口状态栏里(嵌入)显示进度条功能
2019/07/04 Python
Python检测端口IP字符串是否合法
2020/06/05 Python
使用CSS3的appearance属性改变元素的外观的方法
2015/12/12 HTML / CSS
blueseventy官网:铁人三项和比赛泳衣
2021/02/06 全球购物
大四毕业生学习总结的自我评价
2013/10/31 职场文书
平面设计岗位职责
2013/12/14 职场文书
幼儿园运动会口号
2014/06/07 职场文书
道歉短信大全
2015/05/12 职场文书
关于分班的感言
2015/08/04 职场文书
MySQL入门命令之函数-单行函数-流程控制函数
2021/04/05 MySQL
MySQL慢查询的坑
2021/04/28 MySQL
Django给表单添加honeypot验证增加安全性
2021/05/06 Python
python字符串拼接.join()和拆分.split()详解
2021/11/23 Python