ops.py source code

python

Project: 3D_Dense_Transformer_Networks    Author: JohnYC1995
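import tensorflow as tf  # TensorFlow 1.x: relies on tf.layers and tf.contrib


def deconv(inputs, out_num, kernel_size, scope, data_type='2D'):
    """Upsample `inputs` by 2 in every spatial dimension, then batch norm + ReLU.

    inputs: NHWC tensor when data_type == '2D', NDHWC tensor otherwise.
    """
    if data_type == '2D':
        # Stride (2, 2) with 'same' padding doubles height and width.
        outputs = tf.layers.conv2d_transpose(
            inputs, out_num, kernel_size, (2, 2), padding='same', name=scope,
            kernel_initializer=tf.truncated_normal_initializer())
    else:
        input_shape = inputs.shape.as_list()
        # Filter layout: [depth, height, width, out_channels, in_channels].
        shape = list(kernel_size) + [out_num, input_shape[-1]]
        # output_shape must be explicit, so the batch size must be static.
        out_shape = [input_shape[0]] + \
            [dim * 2 for dim in input_shape[1:-1]] + [out_num]
        weights = tf.get_variable(
            scope + '/deconv/weights', shape,
            initializer=tf.truncated_normal_initializer())
        outputs = tf.nn.conv3d_transpose(
            inputs, weights, out_shape, (1, 2, 2, 2, 1), name=scope + '/deconv')
    # updates_collections=None updates the moving averages in place;
    # is_training keeps its default (True), so batch statistics are used.
    return tf.contrib.layers.batch_norm(
        outputs, decay=0.9, epsilon=1e-5, activation_fn=tf.nn.relu,
        updates_collections=None, scope=scope + '/batch_norm')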
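Both branches upsample every spatial dimension by a factor of 2 (stride 2 with 'same' padding) and set the channel count to `out_num`. The 3D branch has to compute this by hand because `tf.nn.conv3d_transpose` takes an explicit `output_shape` and a filter laid out as `[depth, height, width, out_channels, in_channels]`. The pure-Python sketch below reproduces that shape arithmetic without TensorFlow; `deconv_shapes` is a hypothetical helper written for illustration, not part of the project.

```python
from typing import List, Sequence, Tuple


def deconv_shapes(
    input_shape: Sequence[int],
    out_num: int,
    kernel_size: Sequence[int],
    stride: int = 2,
) -> Tuple[List[int], List[int]]:
    """Return (filter_shape, output_shape) for a 'same'-padded transposed conv.

    Hypothetical helper mirroring the shape logic of `deconv`.
    input_shape is batch-first, channels-last (NHWC or NDHWC).
    """
    if len(kernel_size) != len(input_shape) - 2:
        raise ValueError("kernel_size needs one entry per spatial dimension")
    in_channels = input_shape[-1]
    # Transposed-conv filters are laid out as [*kernel, out_channels, in_channels].
    filter_shape = list(kernel_size) + [out_num, in_channels]
    # With 'same' padding, each spatial dimension is multiplied by the stride.
    output_shape = (
        [input_shape[0]]
        + [dim * stride for dim in input_shape[1:-1]]
        + [out_num]
    )
    return filter_shape, output_shape


# A batch of 4 volumes of 16x32x32 voxels with 64 channels, mapped to 32 channels.
filters, out = deconv_shapes([4, 16, 32, 32, 64], out_num=32, kernel_size=(3, 3, 3))
print(filters)  # [3, 3, 3, 32, 64]
print(out)      # [4, 32, 64, 64, 32]
```

`tf.nn.conv3d_transpose` requires the last filter dimension to match the input's channel count, so a filter built as `[..., out_num, out_num]` only works when the input already has `out_num` channels.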