def inflating_convolution(inputs, n_inflation_layers, projection_space_shape=(4, 4, 32), name_prefix=None):
    """Project ``inputs`` onto a small 3-D grid and upsample it with transposed convolutions.

    ``inputs`` first goes through a linear ``Dense`` layer with
    ``prod(projection_space_shape)`` units. The result is reshaped to
    ``projection_space_shape`` and fed through a stack of ``Conv2DTranspose``
    layers. Each layer uses a 5x5 kernel, stride 2, ``'same'`` padding and ReLU.

    Filter schedule, where ``depth = projection_space_shape[2]``:
      * layer 0: ``min(32, depth // 2)``, but at least 1;
      * layer i >= 1: ``depth // 2 ** (i + 1)``, but at least 1.

    Args:
        inputs: Tensor fed to the projection ``Dense`` layer.
        n_inflation_layers: Number of transposed-convolution layers. The first
            layer is always built, so values below 1 still yield one layer.
        projection_space_shape: Sequence of exactly 3 entries that the
            projection is reshaped to. Its last entry drives the filter schedule.
        name_prefix: Prefix for every layer name. When ``None``, ``name=None``
            is passed to the layers so they receive default names. Previously
            this default raised a ``TypeError``.

    Returns:
        The output tensor of the last transposed convolution.

    Raises:
        ValueError: If ``projection_space_shape`` does not have exactly 3 entries.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(projection_space_shape) != 3:
        raise ValueError(
            "Projection space shape has {} dimensions but should have 3.".format(len(projection_space_shape)))

    def _layer_name(suffix):
        # Concatenating onto None would raise TypeError; fall back to default layer naming.
        return None if name_prefix is None else name_prefix + suffix

    # int() so Dense receives a plain Python int even if `prod` returns a NumPy scalar.
    flattened_space_dim = int(prod(projection_space_shape))
    projection = Dense(flattened_space_dim, activation=None, name=_layer_name('_projection'))(inputs)
    reshape = Reshape(projection_space_shape, name=_layer_name('_reshape'))(projection)

    depth = projection_space_shape[2]
    # max(1, ...) mirrors the guard used by the later layers: for depth < 2,
    # depth // 2 == 0 would otherwise request a layer with zero filters.
    first_filters = max(1, min(32, depth // 2))
    # NOTE: the first layer is named '..._transposed_conv_0' while the others use
    # '..._transpose_conv_{i}'. The spelling is kept as-is so that existing weights
    # loaded by layer name and existing get_layer() lookups keep working.
    inflated = Conv2DTranspose(filters=first_filters, kernel_size=(5, 5), strides=(2, 2), activation='relu',
                               padding='same', name=_layer_name('_transposed_conv_0'))(reshape)
    for i in range(1, n_inflation_layers):
        inflated = Conv2DTranspose(filters=max(1, depth // 2 ** (i + 1)), kernel_size=(5, 5),
                                   strides=(2, 2), activation='relu', padding='same',
                                   name=_layer_name('_transpose_conv_{}'.format(i)))(inflated)
    return inflated
# architectures.py (file source)
# python
# Reads: 24
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents