Python通过VGG16模型实现图像风格转换操作详解


Posted in Python onJanuary 16, 2020

本文实例讲述了Python通过VGG16模型实现图像风格转换操作。分享给大家供大家参考,具体如下:

1、图像的风格转化

卷积网络每一层的激活值可以看作一个分类器,多个分类器组成了图像在这一层的抽象表示,而且层数越深,越抽象

内容特征:图片中存在的具体元素,图像输入到CNN后在某一层的激活值

风格特征:绘制图片元素的风格,各个内容之间的共性,图像在CNN网络某一层激活值之间的关联

风格转换:在一幅图片内容特征的基础上添加另一幅图片的风格特征从而生成一幅新的图片。在卷积模型训练中,通过输入固定的图片来调整网络的参数从而达到利用图片训练网络的目的。而在生成特定风格图片时,固定已有的网络参数不变,调整图片从而使图片向目标风格转化。在内容风格转换时,调整图像的像素值,使其向目标图片在卷积网络输出的内容特征靠拢。在风格特征计算时,通过多个神经元的输出两两之间作内积求和得到Gram矩阵,然后对G矩阵做差求均值得到风格的损失函数。

Python通过VGG16模型实现图像风格转换操作详解                             Python通过VGG16模型实现图像风格转换操作详解

将内容损失函数和风格损失函数对应乘以权重再加起来就得到了总的损失函数,最后的生成图既有内容特征也有风格特征

2、通过Vgg16实现

2.1、预训练模型读取

通过预训练好的Vgg16模型来对图片进行风格转换,首先需要准备好vgg16的模型参数。链接: https://pan.baidu.com/s/1shw2M3Iv7UfGjn78dqFAkA 提取码: ejn8

通过numpy.load()导入并查看参数的内容:

import numpy as np
 
data=np.load('./vgg16_model.npy',allow_pickle=True,encoding='bytes')
# print(data.type())
data_dic=data.item()
# 查看网络层参数的键值
print(data_dic.keys())

打印键值如下,可以看到分别有不同的卷积和全连接层:

dict_keys([b'conv5_1', b'fc6', b'conv5_3', b'conv5_2', b'fc8', b'fc7', b'conv4_1',
 b'conv4_2', b'conv4_3', b'conv3_3', b'conv3_2', b'conv3_1', b'conv1_1', b'conv1_2', 
b'conv2_2', b'conv2_1'])

接着查看具体每层的参数,通过data_dic[key]可以获取到key对应层次的参数,例如可以看到卷积层1_1的权值w为3个3×3的卷积核,对应64个输出通道

# 查看卷积层1_1的参数w,b
w,b=data_dic[b'conv1_1']
print(w.shape,b.shape)   # (3, 3, 3, 64) (64,)
# 查看全连接层的参数
w,b=data_dic[b'fc8']
print(w.shape,b.shape)   # (4096, 1000) (1000,)

2.2、构建VGG网络

通过将已经训练好的参数填充到网络之中就可以搭建VGG网络了。

在类初始化函数中读取预训练模型文件中的参数到self.data_dic

首先构建卷积层,通过传入的各个卷积层name参数,读取模型中对应的卷积层参数并填充到网络中。例如读取第一个卷积层的权值和偏置值,传入name='conv1_1,则data_dic[name][0]可以得到权值weight,data_dic[name][1]得到偏置值bias。通过tf.constant构建常量,再执行卷积操作,加偏置项,经激活函数后输出。

接下来实现池化操作,由于池化不需要参数,所以直接对输入进行最大池化操作后输出即可

接着经过展开层,由于卷积池化后的数据是四维向量[batch_size,image_width,image_height,chanel],需要将最后三维展开,将最后三个维度相乘,通过tf.reshape()展开

最后需要把结果经过全连接层,它的实现和卷积层类似,读取权值和偏置参数后进行全连接操作后输出。

class VGGNet:
 def __init__(self, data_dir):
  data = np.load(data_dir, allow_pickle=True, encoding='bytes')
  self.data_dic = data.item()
 
 def conv_layer(self, x, name):
  # 实现卷积操作
  with tf.name_scope(name):
   # 从模型文件中读取各卷积层的参数值
   weight = tf.constant(self.data_dic[name][0], name='conv')
   bias = tf.constant(self.data_dic[name][1], name='bias')
   # 进行卷积操作
   y = tf.nn.conv2d(x, weight, [1, 1, 1, 1], padding='SAME')
   y = tf.nn.bias_add(y, bias)
   return tf.nn.relu(y)
 
 def pooling_layer(self, x, name):
  # 实现池化操作
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
 
 def flatten_layer(self, x, name):
  # 实现展开层
  with tf.name_scope(name):
   # x_shape->[batch_size,image_width,image_height,chanel]
   x_shape = x.get_shape().as_list()
   dimension = 1
   # 计算x的最后三个维度积
   for d in x_shape[1:]:
    dimension *= d
   output = tf.reshape(x, [-1, dimension])
   return output
 
 def fc_layer(self, x, name, activation=tf.nn.relu):
  # 实现全连接层
  with tf.name_scope(name):
   # 从模型文件中读取各全连接层的参数值
   weight = tf.constant(self.data_dic[name][0], name='fc')
   bias = tf.constant(self.data_dic[name][1], name='bias')
   # 进行全连接操作
   y = tf.matmul(x, weight)
   y = tf.nn.bias_add(y, bias)
   if activation==None:
    return y
   else:
    return tf.nn.relu(y)

通过self.build()函数实现Vgg16网络的搭建.数据输入后首先需要进行归一化处理,将输入的RGB数据拆分为R、G、B三个通道,再将三个通道分别减去一个固定值,最后将三通道按B、G、R顺序重新拼接为一个新的数据。

接下来则是通过上面的构建函数来搭建VGG网络,依次将五层的卷积池化网络、展开层、三个全连接层的参数读入各层,并搭建起网络,最后经softmax输出

def build(self,x_rgb):
  s_time=time.time()
  # 归一化处理,在第四维上将输入的图片的三通道拆分
  r,g,b=tf.split(x_rgb,[1,1,1],axis=3)
  # 分别将三通道上减去特定值归一化后再按bgr顺序拼起来
  VGG_MEAN = [103.939, 116.779, 123.68]
  x_bgr=tf.concat(
   [b-VGG_MEAN[0],
   g-VGG_MEAN[1],
   r-VGG_MEAN[2]],
   axis=3
  )
  # 判别拼接起来的数据是否符合期望,符合再继续往下执行
  assert x_bgr.get_shape()[1:]==[668,668,3]
 
  # 构建各个卷积、池化、全连接等层
  self.conv1_1=self.conv_layer(x_bgr,b'conv1_1')
  self.conv1_2=self.conv_layer(self.conv1_1,b'conv1_2')
  self.pool1=self.pooling_layer(self.conv1_2,b'pool1')
 
  self.conv2_1=self.conv_layer(self.pool1,b'conv2_1')
  self.conv2_2=self.conv_layer(self.conv2_1,b'conv2_2')
  self.pool2=self.pooling_layer(self.conv2_2,b'pool2')
 
  self.conv3_1=self.conv_layer(self.pool2,b'conv3_1')
  self.conv3_2=self.conv_layer(self.conv3_1,b'conv3_2')
  self.conv3_3=self.conv_layer(self.conv3_2,b'conv3_3')
  self.pool3=self.pooling_layer(self.conv3_3,b'pool3')
 
  self.conv4_1 = self.conv_layer(self.pool3, b'conv4_1')
  self.conv4_2 = self.conv_layer(self.conv4_1, b'conv4_2')
  self.conv4_3 = self.conv_layer(self.conv4_2, b'conv4_3')
  self.pool4 = self.pooling_layer(self.conv4_3, b'pool4')
 
  self.conv5_1 = self.conv_layer(self.pool4, b'conv5_1')
  self.conv5_2 = self.conv_layer(self.conv5_1, b'conv5_2')
  self.conv5_3 = self.conv_layer(self.conv5_2, b'conv5_3')
  self.pool5 = self.pooling_layer(self.conv5_3, b'pool5')
 
  self.flatten=self.flatten_layer(self.pool5,b'flatten')
  self.fc6=self.fc_layer(self.flatten,b'fc6')
  self.fc7 = self.fc_layer(self.fc6, b'fc7')
  self.fc8 = self.fc_layer(self.fc7, b'fc8',activation=None)
  self.prob=tf.nn.softmax(self.fc8,name='prob')
 
  print('模型构建完成,用时%d秒'%(time.time()-s_time))

2.3、图像风格转换

首先需要定义网络的输入与输出。网络的输入是风格图像和内容图像,两张图象都是668×668的3通道图片。首先通过PIL库中的Image对象完成读入内容图像style_img和风格图像content_img,并将其转化为数组,定义对应的占位符style_in和content_in,在训练时将图片填入。

网络的输出是一张结果图片668×668的3通道,通过随机函数初始化一个结果图像的数组res_out。

利用上面定义的VGGNet类来创建图片对象,并完成build操作。

vgg16_dir = './data/vgg16_model.npy'
style_img = './data/starry_night.jpg'
content_img = './data/city_night.jpg'
output_dir = './data'
 
 
def read_image(img):
 img = Image.open(img)
 img_np = np.array(img) # 将图片转化为[668,668,3]数组
 img_np = np.asarray([img_np], ) # 转化为[1,668,668,3]的数组
 return img_np
 
 
# 输入风格、内容图像数组
style_img = read_image(style_img)
content_img = read_image(content_img)
# 定义对应的输入图像的占位符
content_in = tf.placeholder(tf.float32, shape=[1, 668, 668, 3])
style_in = tf.placeholder(tf.float32, shape=[1, 668, 668, 3])
 
# 初始化输出的图像
initial_img = tf.truncated_normal((1, 668, 668, 3), mean=127.5, stddev=20)
res_out = tf.Variable(initial_img)
 
# 构建VGG网络对象
res_net = VGGNet(vgg16_dir)
style_net = VGGNet(vgg16_dir)
content_net = VGGNet(vgg16_dir)
res_net.build(res_out)
style_net.build(style_in)
content_net.build(content_in)

接着需要定义损失函数loss

对于内容损失,先选定内容风格图像和结果图像的卷积层,要相同,比如这里选取了卷积层1_1和2_1。然后这两个特征层的后三个通道求平方差,然后取均值,就是内容损失。

对于风格损失,首先需要对风格图像和结果图像的特征层求gram矩阵,然后对gram矩阵求平方差的均值。

最后按照系数比例将两个损失函数相加即可得到loss

# 计算损失,分别需要计算内容损失和风格损失
# 提取内容图像的内容特征
content_features = [
 content_net.conv1_2,
 content_net.conv2_2
 # content_net.conv2_2
]
# 对应结果图像提取相同层的内容特征
res_content = [
 res_net.conv1_2,
 res_net.conv2_2
 # res_net.conv2_2
]
# 计算内容损失
content_loss = tf.zeros(1, tf.float32)
for c, r in zip(content_features, res_content):
 content_loss += tf.reduce_mean((c - r) ** 2, [1, 2, 3])
 
 
# 计算风格损失的gram矩阵
def gram_matrix(x):
 b, w, h, ch = x.get_shape().as_list()
 features = tf.reshape(x, [b, w * h, ch])
 # 对features矩阵作内积,再除以一个常数
 gram = tf.matmul(features, features, adjoint_a=True) / tf.constant(w * h * ch, tf.float32)
 return gram
 
 
# 对风格图像提取特征
style_features = [
 # style_net.conv1_2
 style_net.conv4_3
]
style_gram = [gram_matrix(feature) for feature in style_features]
# 提取结果图像对应层的风格特征
res_features = [
 res_net.conv4_3
]
res_gram = [gram_matrix(feature) for feature in res_features]
# 计算风格损失
style_loss = tf.zeros(1, tf.float32)
for s, r in zip(style_gram, res_gram):
 style_loss += tf.reduce_mean((s - r) ** 2, [1, 2])
 
# 模型内容、风格特征的系数
k_content = 0.1
k_style = 500
# 按照系数将两个损失值相加
loss = k_content * content_loss + k_style * style_loss

接下来开始进行100轮的训练,打印并查看过程中的总损失、内容损失、风格损失值。并将每轮的生成结果图片输出到指定目录下

# 进行训练
learning_steps = 100
learning_rate = 10
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
 
with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())
 for i in range(learning_steps):
  t_loss, c_loss, s_loss, _ = sess.run(
   [loss, content_loss, style_loss, train_op],
   feed_dict={content_in: content_img, style_in: style_img}
  )
  print('第%d轮训练,总损失:%.4f,内容损失:%.4f,风格损失:%.4f'
    % (i + 1, t_loss[0], c_loss[0], s_loss[0]))
  # 获取结果图像数组并保存
  res_arr = res_out.eval(sess)[0]
  res_arr = np.clip(res_arr, 0, 255) # 将结果数组中的值裁剪到0~255
  res_arr = np.asarray(res_arr, np.uint8) # 将图片数组转化为uint8
  img_path = os.path.join(output_dir, 'res_%d.jpg' % (i + 1))
  # 图像数组转化为图片
  res_img = Image.fromarray(res_arr)
  res_img.save(img_path)

运行结果如下可以看到依次分别为内容图片、风格图片、训练12轮、46轮、100轮结果图片

Python通过VGG16模型实现图像风格转换操作详解    Python通过VGG16模型实现图像风格转换操作详解 

Python通过VGG16模型实现图像风格转换操作详解      Python通过VGG16模型实现图像风格转换操作详解    Python通过VGG16模型实现图像风格转换操作详解

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python不使用int()函数把字符串转换为数字的方法
Jul 09 Python
python保存网页图片到本地的方法
Jul 24 Python
Python在for循环中更改list值的方法【推荐】
Aug 17 Python
使用PM2+nginx部署python项目的方法示例
Nov 07 Python
解决PySide+Python子线程更新UI线程的问题
Jan 11 Python
解决nohup执行python程序log文件写入不及时的问题
Jan 14 Python
python从入门到精通 windows安装python图文教程
May 18 Python
python异步实现定时任务和周期任务的方法
Jun 29 Python
python版百度语音识别功能
Jul 09 Python
Python TCP通信客户端服务端代码实例
Nov 21 Python
python正则表达式实例代码
Mar 03 Python
Python转换字典成为对象,可以用"."方式访问对象属性实例
May 11 Python
Python使用turtle库绘制小猪佩奇(实例代码)
Jan 16 #Python
PyCharm汉化安装及永久激活详细教程(靠谱)
Jan 16 #Python
python如何使用Redis构建分布式锁
Jan 16 #Python
Python中url标签使用知识点总结
Jan 16 #Python
PyTorch的SoftMax交叉熵损失和梯度用法
Jan 15 #Python
pytorch方法测试——激活函数(ReLU)详解
Jan 15 #Python
pytorch的batch normalize使用详解
Jan 15 #Python
You might like
PHP 中文乱码解决办法总结分析
2009/07/30 PHP
PHP数组循环操作详细介绍 附实例代码
2013/02/03 PHP
Ubuntu VPS中wordpress网站打开时提示”建立数据库连接错误”的解决办法
2016/11/03 PHP
PHP创建单例后台进程的方法示例
2017/05/23 PHP
解决laravel 5.1报错:No supported encrypter found的办法
2017/06/07 PHP
刷新时清空文本框内容的js代码
2007/04/23 Javascript
js Flash插入函数免激活代码
2009/03/31 Javascript
extjs fckeditor集成代码
2009/05/10 Javascript
javascript实现网页子页面遍历回调的方法(涉及 window.frames、递归函数、函数上下文)
2015/07/27 Javascript
jquery实现实时改变网页字体大小、字体背景色和颜色的方法
2015/08/05 Javascript
JavaScript从0开始构思表情插件
2016/07/26 Javascript
jquery动态添加文本并获取值的方法
2016/10/12 Javascript
小程序实现带年月选取效果的日历
2018/06/27 Javascript
vue-resource请求实现http登录拦截或者路由拦截的方法
2018/07/11 Javascript
详解Vue This$Store总结
2018/12/17 Javascript
JavaScript实现移动端拖动元素
2020/11/24 Javascript
[02:28]DOTA2英雄基础教程 灰烬之灵
2013/12/19 DOTA
Python中实现结构相似的函数调用方法
2015/03/10 Python
Python中用altzone()方法处理时区的教程
2015/05/22 Python
Python中装饰器兼容加括号和不加括号的写法详解
2017/07/05 Python
Python排序算法实例代码
2017/08/10 Python
python 不以科学计数法输出的方法
2018/07/16 Python
浅谈python3发送post请求参数为空的情况
2018/12/28 Python
用xpath获取指定标签下的所有text的实例
2019/01/02 Python
Python实现的多进程拷贝文件并显示百分比功能示例
2019/04/09 Python
CSS3 @font-face属性使用指南
2014/12/12 HTML / CSS
举例详解HTML5中使用JSON格式提交表单
2015/06/16 HTML / CSS
浅析HTML5中的 History 模式
2017/06/22 HTML / CSS
致跳远、跳高运动员广播稿
2014/01/09 职场文书
银行实习生的自我评价
2014/01/13 职场文书
法院四风对照检查材料思想汇报
2014/10/06 职场文书
简易离婚协议书范本2014
2014/10/15 职场文书
人事局接收函
2015/01/31 职场文书
写作技巧:如何撰写一份优秀的营销策划书
2019/08/13 职场文书
JavaScript实现贪吃蛇游戏
2021/06/16 Javascript
总结几个非常实用的Python库
2021/06/26 Python