architectures.py source file

python

Project: adversarial-variational-bayes  Author: gdikov
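# Assumed imports: the snippet does not show them.
from numpy import prod
from keras.layers import Conv2DTranspose, Dense, Reshape


def inflating_convolution(inputs, n_inflation_layers, projection_space_shape=(4, 4, 32), name_prefix=None):
    """Project `inputs` to `projection_space_shape`, then upsample 2x per transposed convolution."""
    assert len(projection_space_shape) == 3, \
        "Projection space shape has {} dimensions but should have 3.".format(len(projection_space_shape))

    def layer_name(suffix):
        # With name_prefix=None the original string concatenation raised TypeError; let Keras auto-name instead.
        return None if name_prefix is None else name_prefix + suffix

    flattened_space_dim = int(prod(projection_space_shape))
    projection = Dense(flattened_space_dim, activation=None, name=layer_name('_projection'))(inputs)
    reshape = Reshape(projection_space_shape, name=layer_name('_reshape'))(projection)
    depth = projection_space_shape[2]
    inflated = Conv2DTranspose(filters=min(32, depth // 2), kernel_size=(5, 5), strides=(2, 2), activation='relu',
                               padding='same', name=layer_name('_transposed_conv_0'))(reshape)
    for i in range(1, n_inflation_layers):
        # Filters shrink as depth // 2**(i+1), floored at 1.
        inflated = Conv2DTranspose(filters=max(1, depth // 2**(i+1)), kernel_size=(5, 5),
                                   strides=(2, 2), activation='relu', padding='same',
                                   name=layer_name('_transpose_conv_{}'.format(i)))(inflated)
    return inflated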
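
To make the shape arithmetic of the function above easier to follow, here is a minimal sketch that mirrors its layer loop in plain Python. It relies on the fact that a stride-2 transposed convolution with `padding='same'` doubles each spatial dimension. The helper name `inflated_output_shape` is hypothetical and not part of the original project.

```python
def inflated_output_shape(n_inflation_layers, projection_space_shape=(4, 4, 32)):
    """Return (height, width, channels) produced by the transposed-convolution stack.

    Hypothetical helper for illustration only; it mirrors the filter and
    stride settings of inflating_convolution without building any layers.
    """
    height, width, depth = projection_space_shape
    # Layer 0 is always applied, even when n_inflation_layers is 0 or 1.
    height, width, channels = height * 2, width * 2, min(32, depth // 2)
    for i in range(1, n_inflation_layers):
        # Stride 2 with padding='same' doubles each spatial dimension.
        height, width = height * 2, width * 2
        channels = max(1, depth // 2 ** (i + 1))
    return height, width, channels


print(inflated_output_shape(3))  # (32, 32, 4)
```

With the default `(4, 4, 32)` projection, three layers take the feature map through 8x8x16 and 16x16x8 to 32x32x4, so the caller still needs a final layer to reach the desired number of output channels.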