Pytorch转keras的有效方法,以FlowNet为例讲解


Posted in Python onMay 26, 2020

Pytorch凭借动态图机制,获得了广泛的使用,大有超越tensorflow的趋势,不过在工程应用上,TF仍然占据优势。有的时候我们会遇到这种情况,需要把模型应用到工业中,运用到实际项目上,TF支持的PB文件和TF的C++接口就成为了有效的工具。今天就给大家讲解一下Pytorch转成Keras的方法,进而我们也可以获得Pb文件,因为Keras是支持tensorflow的,我将会在下一篇博客讲解获得Pb文件,并使用Pb文件的方法。

Pytorch To Keras

首先,我们必须有清楚的认识,网上以及github上一些所谓的pytorch转换Keras或者Keras转换成Pytorch的工具代码几乎不能运行或者有使用的局限性(比如仅仅能转换某一些模型),但是我们是可以用这些转换代码中看出一些端倪来,比如二者的参数的尺寸(shape)的形式、channel的排序(first or last)是否一样,掌握到差异性,就能根据这些差异自己编写转换代码,没错,自己编写转换代码,是最稳妥的办法。整个过程也就分为两个部分。笔者将会以Nvidia开源的FlowNet为例,将开源的Pytorch代码转化为Keras模型。

按照Pytorch中模型的结构,编写对应的Keras代码,用keras的函数式API,构建起来会非常方便。

把Pytorch的模型参数,按照层的名称依次赋值给Keras的模型

以上两步虽然看上去简单,但实际我也走了不少弯路。这里一个关键的地方,就是参数的shape在两个框架中是否统一,那当然是不统一的。下面我以FlowNet为例。

Pytorch中的FlowNet代码

我们仅仅展示层名称和层参数,就不把整个结构贴出来了,否则会占很多的空间,形成水文。

先看用Keras搭建的flowNet模型,直接用model.summary()输出模型信息

__________________________________________________________________________________________________
Layer (type)   Output Shape  Param # Connected to   
==================================================================================================
input_1 (InputLayer)  (None, 6, 512, 512) 0      
__________________________________________________________________________________________________
conv0 (Conv2D)   (None, 64, 512, 512) 3520 input_1[0][0]   
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 64, 512, 512) 0  conv0[0][0]   
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 64, 514, 514) 0  leaky_re_lu_1[0][0]  
__________________________________________________________________________________________________
conv1 (Conv2D)   (None, 64, 256, 256) 36928 zero_padding2d_1[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 64, 256, 256) 0  conv1[0][0]   
__________________________________________________________________________________________________
conv1_1 (Conv2D)  (None, 128, 256, 256 73856 leaky_re_lu_2[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 128, 256, 256 0  conv1_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 128, 258, 258 0  leaky_re_lu_3[0][0]  
__________________________________________________________________________________________________
conv2 (Conv2D)   (None, 128, 128, 128 147584 zero_padding2d_2[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 128, 128, 128 0  conv2[0][0]   
__________________________________________________________________________________________________
conv2_1 (Conv2D)  (None, 128, 128, 128 147584 leaky_re_lu_4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 128, 128, 128 0  conv2_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 128, 130, 130 0  leaky_re_lu_5[0][0]  
__________________________________________________________________________________________________
conv3 (Conv2D)   (None, 256, 64, 64) 295168 zero_padding2d_3[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 256, 64, 64) 0  conv3[0][0]   
__________________________________________________________________________________________________
conv3_1 (Conv2D)  (None, 256, 64, 64) 590080 leaky_re_lu_6[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 256, 64, 64) 0  conv3_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_4 (ZeroPadding2D (None, 256, 66, 66) 0  leaky_re_lu_7[0][0]  
__________________________________________________________________________________________________
conv4 (Conv2D)   (None, 512, 32, 32) 1180160 zero_padding2d_4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU) (None, 512, 32, 32) 0  conv4[0][0]   
__________________________________________________________________________________________________
conv4_1 (Conv2D)  (None, 512, 32, 32) 2359808 leaky_re_lu_8[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU) (None, 512, 32, 32) 0  conv4_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_5 (ZeroPadding2D (None, 512, 34, 34) 0  leaky_re_lu_9[0][0]  
__________________________________________________________________________________________________
conv5 (Conv2D)   (None, 512, 16, 16) 2359808 zero_padding2d_5[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU) (None, 512, 16, 16) 0  conv5[0][0]   
__________________________________________________________________________________________________
conv5_1 (Conv2D)  (None, 512, 16, 16) 2359808 leaky_re_lu_10[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU) (None, 512, 16, 16) 0  conv5_1[0][0]   
__________________________________________________________________________________________________
zero_padding2d_6 (ZeroPadding2D (None, 512, 18, 18) 0  leaky_re_lu_11[0][0]  
__________________________________________________________________________________________________
conv6 (Conv2D)   (None, 1024, 8, 8) 4719616 zero_padding2d_6[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_12 (LeakyReLU) (None, 1024, 8, 8) 0  conv6[0][0]   
__________________________________________________________________________________________________
conv6_1 (Conv2D)  (None, 1024, 8, 8) 9438208 leaky_re_lu_12[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_13 (LeakyReLU) (None, 1024, 8, 8) 0  conv6_1[0][0]   
__________________________________________________________________________________________________
deconv5 (Conv2DTranspose) (None, 512, 16, 16) 8389120 leaky_re_lu_13[0][0]  
__________________________________________________________________________________________________
predict_flow6 (Conv2D)  (None, 2, 8, 8) 18434 leaky_re_lu_13[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_14 (LeakyReLU) (None, 512, 16, 16) 0  deconv5[0][0]   
__________________________________________________________________________________________________
upsampled_flow6_to_5 (Conv2DTra (None, 2, 16, 16) 66  predict_flow6[0][0]  
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 1026, 16, 16) 0  leaky_re_lu_11[0][0]  
         leaky_re_lu_14[0][0]  
         upsampled_flow6_to_5[0][0] 
__________________________________________________________________________________________________
inter_conv5 (Conv2D)  (None, 512, 16, 16) 4728320 concatenate_1[0][0]  
__________________________________________________________________________________________________
deconv4 (Conv2DTranspose) (None, 256, 32, 32) 4202752 concatenate_1[0][0]  
__________________________________________________________________________________________________
predict_flow5 (Conv2D)  (None, 2, 16, 16) 9218 inter_conv5[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_15 (LeakyReLU) (None, 256, 32, 32) 0  deconv4[0][0]   
__________________________________________________________________________________________________
upsampled_flow5_to4 (Conv2DTran (None, 2, 32, 32) 66  predict_flow5[0][0]  
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 770, 32, 32) 0  leaky_re_lu_9[0][0]  
         leaky_re_lu_15[0][0]  
         upsampled_flow5_to4[0][0] 
__________________________________________________________________________________________________
inter_conv4 (Conv2D)  (None, 256, 32, 32) 1774336 concatenate_2[0][0]  
__________________________________________________________________________________________________
deconv3 (Conv2DTranspose) (None, 128, 64, 64) 1577088 concatenate_2[0][0]  
__________________________________________________________________________________________________
predict_flow4 (Conv2D)  (None, 2, 32, 32) 4610 inter_conv4[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_16 (LeakyReLU) (None, 128, 64, 64) 0  deconv3[0][0]   
__________________________________________________________________________________________________
upsampled_flow4_to3 (Conv2DTran (None, 2, 64, 64) 66  predict_flow4[0][0]  
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 386, 64, 64) 0  leaky_re_lu_7[0][0]  
         leaky_re_lu_16[0][0]  
         upsampled_flow4_to3[0][0] 
__________________________________________________________________________________________________
inter_conv3 (Conv2D)  (None, 128, 64, 64) 444800 concatenate_3[0][0]  
__________________________________________________________________________________________________
deconv2 (Conv2DTranspose) (None, 64, 128, 128) 395328 concatenate_3[0][0]  
__________________________________________________________________________________________________
predict_flow3 (Conv2D)  (None, 2, 64, 64) 2306 inter_conv3[0][0]  
__________________________________________________________________________________________________
leaky_re_lu_17 (LeakyReLU) (None, 64, 128, 128) 0  deconv2[0][0]   
__________________________________________________________________________________________________
upsampled_flow3_to2 (Conv2DTran (None, 2, 128, 128) 66  predict_flow3[0][0]  
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 194, 128, 128 0  leaky_re_lu_5[0][0]  
         leaky_re_lu_17[0][0]  
         upsampled_flow3_to2[0][0] 
__________________________________________________________________________________________________
inter_conv2 (Conv2D)  (None, 64, 128, 128) 111808 concatenate_4[0][0]  
__________________________________________________________________________________________________
predict_flow2 (Conv2D)  (None, 2, 128, 128) 1154 inter_conv2[0][0]  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 2, 512, 512) 0  predict_flow2[0][0]

再看看Pytorch搭建的flownet模型

(conv0): Sequential(
 (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv1): Sequential(
 (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv1_1): Sequential(
 (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv2): Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv2_1): Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv3): Sequential(
 (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv3_1): Sequential(
 (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv4): Sequential(
 (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv4_1): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv5): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv5_1): Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv6): Sequential(
 (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (conv6_1): Sequential(
 (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv5): Sequential(
 (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv4): Sequential(
 (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv3): Sequential(
 (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (deconv2): Sequential(
 (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
 )
 (inter_conv5): Sequential(
 (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv4): Sequential(
 (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv3): Sequential(
 (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (inter_conv2): Sequential(
 (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 )
 (predict_flow6): Conv2d(1024, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow5): Conv2d(512, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow4): Conv2d(256, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow3): Conv2d(128, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (predict_flow2): Conv2d(64, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (upsampled_flow6_to_5): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow5_to_4): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow4_to_3): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsampled_flow3_to_2): ConvTranspose2d(2, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (upsample1): Upsample(scale_factor=4.0, mode=bilinear)
)
conv0 Sequential(
 (0): Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv0.0 Conv2d(6, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv0.1 LeakyReLU(negative_slope=0.1, inplace)
conv1 Sequential(
 (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1.0 Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv1.1 LeakyReLU(negative_slope=0.1, inplace)
conv1_1 Sequential(
 (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv1_1.0 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv1_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv2 Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv2.1 LeakyReLU(negative_slope=0.1, inplace)
conv2_1 Sequential(
 (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv2_1.0 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv2_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv3 Sequential(
 (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3.0 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv3.1 LeakyReLU(negative_slope=0.1, inplace)
conv3_1 Sequential(
 (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv3_1.0 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv3_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv4 Sequential(
 (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4.0 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv4.1 LeakyReLU(negative_slope=0.1, inplace)
conv4_1 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv4_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv4_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv5 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv5.1 LeakyReLU(negative_slope=0.1, inplace)
conv5_1 Sequential(
 (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv5_1.0 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv5_1.1 LeakyReLU(negative_slope=0.1, inplace)
conv6 Sequential(
 (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6.0 Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
conv6.1 LeakyReLU(negative_slope=0.1, inplace)
conv6_1 Sequential(
 (0): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
conv6_1.0 Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
conv6_1.1 LeakyReLU(negative_slope=0.1, inplace)
deconv5 Sequential(
 (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv5.0 ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv5.1 LeakyReLU(negative_slope=0.1, inplace)
deconv4 Sequential(
 (0): ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv4.0 ConvTranspose2d(1026, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv4.1 LeakyReLU(negative_slope=0.1, inplace)
deconv3 Sequential(
 (0): ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv3.0 ConvTranspose2d(770, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv3.1 LeakyReLU(negative_slope=0.1, inplace)
deconv2 Sequential(
 (0): ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
 (1): LeakyReLU(negative_slope=0.1, inplace)
)
deconv2.0 ConvTranspose2d(386, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
deconv2.1 LeakyReLU(negative_slope=0.1, inplace)
inter_conv5 Sequential(
 (0): Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv5.0 Conv2d(1026, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv4 Sequential(
 (0): Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv4.0 Conv2d(770, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv3 Sequential(
 (0): Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
inter_conv3.0 Conv2d(386, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
inter_conv2 Sequential(
 (0): Conv2d(194, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

因为Pytorch模型用name_modules()输出不是按顺序的,动态图机制决定了只有在有数据流动之后才知道走过的路径。所以上面的顺序也是乱的。但我想表明的是,我用Keras搭建的模型确实是根据官方开源的Pytorch模型搭建的。

模型搭建完毕之后,就到了关键的步骤:给Keras模型赋值。

给Keras模型赋值

这个步骤其实注意三个点

Pytorch是channels_first的,Keras默认是channels_last,在代码开头加上这两句:

K.set_image_data_format(‘channels_first')
K.set_learning_phase(0)

众所周知,卷积层的权重是一个4维张量,那么,在Pytorch和keras中,卷积核的权重的形式是否一致的,那自然是不一致的,要不然我为啥还要写这一点。那么就涉及到Pytorch权重的变形。

既然卷积层权重形式在两个框架是不一致的,转置卷积自然也是不一致的。

我们先看看卷积层在两个框架中的形式

keras的卷积层权重形式

我们用以下代码看keras卷积层权重形式

for l in model.layers:
  print(l.name)
  for i, w in enumerate(l.get_weights()):
   print('%d'%i , w.shape)

第一个卷积层输出如下 0之后是卷积权重的shape,1之后的是偏置项

conv0
0 (3, 3, 6, 64)
1 (64,)

所以Keras的卷积层权重形式是[ height, width, input_channels, out_channels]

Pytorch的卷积层权重形式

net = FlowNet2SD()
 for n, m in net.named_parameters():
  print(n)
  print(m.data.size())

conv0.0.weight
torch.Size([64, 6, 3, 3])
conv0.0.bias
torch.Size([64])

用上面的代码得到所有层的参数的shape,同样找到第一个卷积层的参数,查看shape。

通过对比我们可以发现,Pytorch的卷积层shape是[ out_channels, input_channels, height, width]的形式。

那么我们在取出Pytorch权重之后,需要用np.transpose改变一下权重的排序,才能送到Keras模型对应的层上。

Keras中转置卷积权重形式

deconv4
0 (4, 4, 256, 1026)
1 (256,)

代码仍然和上面一样,找到转置卷积的对应的位置,查看一下

可以看出在Keras中,转置卷积形式是 [ height, width, out_channels, input_channels]

Pytorch中转置卷积权重形式

deconv4.0.weight
torch.Size([1026, 256, 4, 4])
deconv4.0.bias
torch.Size([256])

代码仍然和上面一样,找到转置卷积的对应的位置,查看一下

可以看出在Pytorch中,转置卷积形式是 [ input_channels,out_channels,height, width]

小结

对于卷积层来说,Pytorch的权重需要使用

np.transpose(weight.data.numpy(), [2, 3, 1, 0])

才能赋值给keras模型对应的层的权重。

对于转置卷积来说,通过对比其实也是一样的。不信你去试试嘛。O(∩_∩)O哈哈~

对于偏置项,两种模块都是一维的向量,不需要处理。

有的情况还可能需要通道颠倒一下,但是很少需要这样做。

weights[::-1,::-1,:,:]

赋值

结束了预处理之后,我们就进入第二步,开始赋值了。

先看预处理的代码:

for k,v in weights_from_torch.items():
 if 'bias' not in k:
  weights_from_torch[k] = v.data.numpy().transpose(2, 3, 1, 0)

赋值代码我只截了一部分供大家参考:

k_model = k_model()
for layer in k_model.layers:
 current_layer_name = layer.name
 if current_layer_name=='conv0':
  weights = [weights_from_torch['conv0.0.weight'],weights_from_torch['conv0.0.bias']]
  layer.set_weights(weights)
 elif current_layer_name=='conv1':
  weights = [weights_from_torch['conv1.0.weight'],weights_from_torch['conv1.0.bias']]
  layer.set_weights(weights)
 elif current_layer_name=='conv1_1':
  weights = [weights_from_torch['conv1_1.0.weight'],weights_from_torch['conv1_1.0.bias']]
  layer.set_weights(weights)

首先就是定义Keras模型,用layers获得所有层的迭代器。

遍历迭代器,对一个层赋予相应的值。

赋值需要用save_weights,其参数需要是一个列表,形式和get_weights的返回结果一致,即 [ conv_weights, bias_weights]

最后祝愿大家能实现自己模型的迁移。工程开源在了个人Github,有详细的使用介绍,并且包含使用数据,大家可以直接运行。

以上这篇Pytorch转keras的有效方法,以FlowNet为例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python删除windows垃圾文件的方法
Jul 14 Python
在python中使用正则表达式查找可嵌套字符串组
Oct 24 Python
python绘制双柱形图代码实例
Dec 14 Python
python获取点击的坐标画图形的方法
Jul 09 Python
python接口调用已训练好的caffe模型测试分类方法
Aug 26 Python
NumPy统计函数的实现方法
Jan 21 Python
详解Python中如何将数据存储为json格式的文件
Nov 18 Python
pycharm 关闭search everywhere的解决操作
Jan 15 Python
Python tkinter实现日期选择器
Feb 22 Python
Python爬取英雄联盟MSI直播间弹幕并生成词云图
Jun 01 Python
Python 恐龙跑跑小游戏实现流程
Feb 15 Python
yolov5返回坐标的方法实例
Mar 17 Python
Django+Celery实现动态配置定时任务的方法示例
May 26 #Python
python删除某个目录文件夹的方法
May 26 #Python
Pytorch使用PIL和Numpy将单张图片转为Pytorch张量方式
May 25 #Python
Pytorch转onnx、torchscript方式
May 25 #Python
使用pandas库对csv文件进行筛选保存
May 25 #Python
pytorch中 gpu与gpu、gpu与cpu 在load时相互转化操作
May 25 #Python
基于pandas向csv添加新的行和列
May 25 #Python
You might like
安健A254立体声随身听的分析与打磨
2021/03/02 无线电
使用 MySQL Date/Time 类型
2008/03/26 PHP
MySQL数据源表结构图示
2008/06/05 PHP
PHP读取大文件的类SplFileObject使用介绍
2014/04/09 PHP
PHP设计模式(一)工厂模式Factory实例详解【创建型】
2020/05/02 PHP
认识延迟时间为0的setTimeout
2008/05/16 Javascript
Jquery 最近浏览过的商品的功能实现代码
2010/05/14 Javascript
Jquery实现简单的动画效果代码
2012/03/18 Javascript
Js 代码中,ajax请求地址后加随机数防止浏览器缓存的原因
2013/05/07 Javascript
用js来刷新当前页面保留参数的具体实现
2013/12/23 Javascript
使用变量动态设置js的属性名
2014/10/19 Javascript
基于jQuery实现多标签页切换的效果(web前端开发)
2016/07/24 Javascript
jQuery+HTML5实现弹出创意搜索框层
2016/12/29 Javascript
Angularjs之ngModel中的值验证绑定方法
2018/09/13 Javascript
LayUi使用switch开关,动态的去控制它是否被启用的方法
2019/09/21 Javascript
微信浏览器左上角返回按钮监听的实现
2020/03/04 Javascript
vue-resource post数据时碰到Django csrf问题的解决
2020/03/13 Javascript
基于vue实现微博三方登录流程解析
2020/11/04 Javascript
[02:41]DOTA2英雄基础教程 亚巴顿
2014/01/02 DOTA
python实现的文件夹清理程序分享
2014/11/22 Python
Python常见加密模块用法分析【MD5,sha,crypt模块】
2017/05/24 Python
Python生成随机验证码代码实例解析
2020/06/09 Python
Django中F函数的使用示例代码详解
2020/07/06 Python
解决pycharm不能自动保存在远程linux中的问题
2021/02/06 Python
关联、聚合(Aggregation)以及组合(Composition)的区别
2012/02/29 面试题
几个常见的软件测试问题
2016/09/07 面试题
工地门卫岗位职责
2013/12/30 职场文书
小学教师听课制度
2014/02/01 职场文书
学校招生宣传广告词
2014/03/19 职场文书
校长寄语大全
2014/04/09 职场文书
教师爱岗敬业演讲稿
2014/05/05 职场文书
局机关干部群众路线个人对照检查材料思想汇报
2014/10/05 职场文书
店面出租协议书范本
2014/11/28 职场文书
离职感谢信
2015/01/21 职场文书
教你如何使用Python Tkinter库制作记事本
2021/06/10 Python
Java实现注册登录跳转
2022/06/16 Java/Android