ops.py source code

Language: Python

Project: PixelDCN, Author: HongyangGao
import tensorflow as tf  # TensorFlow 1.x: relies on tf.layers and tf.contrib


def conv(inputs, out_num, kernel_size, scope, data_type='2D', norm=True):
    """Stride-1 'same' convolution followed by batch normalization.

    data_type='2D' uses tf.layers.conv2d; any other value builds a 3-D
    convolution with tf.nn.conv3d. Batch norm is always applied; `norm`
    only controls whether a ReLU follows it.
    """
    if data_type == '2D':
        outs = tf.layers.conv2d(
            inputs, out_num, kernel_size, padding='same', name=scope+'/conv',
            kernel_initializer=tf.truncated_normal_initializer())
    else:
        # conv3d filter layout: [depth, height, width, in_channels, out_channels]
        shape = list(kernel_size) + [inputs.shape[-1].value, out_num]
        weights = tf.get_variable(
            scope+'/conv/weights', shape,
            initializer=tf.truncated_normal_initializer())
        outs = tf.nn.conv3d(
            inputs, weights, (1, 1, 1, 1, 1), padding='SAME',
            name=scope+'/conv')
    # updates_collections=None forces the moving-average updates in place.
    # is_training keeps its default (True), so batch statistics are used.
    return tf.contrib.layers.batch_norm(
        outs, decay=0.9, epsilon=1e-5,
        activation_fn=tf.nn.relu if norm else None,
        updates_collections=None, scope=scope+'/batch_norm')
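The 3-D branch builds its filter by appending the input and output channel counts to `kernel_size`, which matches the [depth, height, width, in_channels, out_channels] layout that `tf.nn.conv3d` expects, and both branches use stride 1 with SAME padding. The pure-Python sketch below reproduces that shape arithmetic. The helper names `conv3d_kernel_shape`, `same_output_size` and `same_padding` are hypothetical (not part of PixelDCN), and the sketch assumes TensorFlow's documented SAME rule: output size = ceil(input size / stride), with any odd leftover padding placed at the end.

```python
def conv3d_kernel_shape(kernel_size, in_channels, out_num):
    """Hypothetical helper: filter shape built in conv()'s 3-D branch."""
    # [depth, height, width] + [in_channels, out_channels]
    return list(kernel_size) + [in_channels, out_num]


def same_output_size(in_size, stride=1):
    """Hypothetical helper: spatial size after a SAME-padded convolution."""
    # ceil(in_size / stride) using integer arithmetic
    return -(-in_size // stride)


def same_padding(in_size, kernel, stride=1):
    """Hypothetical helper: (pad_before, pad_after) for SAME along one axis."""
    out = same_output_size(in_size, stride)
    total = max((out - 1) * stride + kernel - in_size, 0)
    # The extra pixel, if the total is odd, goes at the end
    return total // 2, total - total // 2


if __name__ == '__main__':
    print(conv3d_kernel_shape((3, 3, 3), 16, 32))  # [3, 3, 3, 16, 32]
    print(same_output_size(17))                    # 17: stride 1 keeps the size
    print(same_padding(17, 3))                     # (1, 1)
```

Because `conv()` always uses stride 1, a kernel of size k gets k - 1 pixels of padding per axis, so the spatial size of the output equals that of the input.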
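`tf.contrib.layers.batch_norm` is called with `decay=0.9` and `epsilon=1e-5` and otherwise keeps its defaults (`center=True`, `scale=False`, `is_training=True`). So each call normalizes with the current batch's statistics, adds a learned offset beta (no gamma), and moves the running averages toward the batch values. The sketch below restates that for a single channel in plain Python. The function name is hypothetical, `relu` stands in for the `norm` flag of `conv()`, and it uses the biased batch variance throughout; TensorFlow's fused kernel may differ slightly in the variance it feeds into the moving average.

```python
import math


def batch_norm_train_step(values, moving_mean, moving_var,
                          beta=0.0, decay=0.9, epsilon=1e-5, relu=True):
    """Hypothetical sketch: training-mode batch norm for one channel."""
    n = len(values)
    batch_mean = sum(values) / n
    # Biased (population) variance of the current batch
    batch_var = sum((v - batch_mean) ** 2 for v in values) / n
    # Exponential moving averages, later used at inference time
    new_mean = decay * moving_mean + (1 - decay) * batch_mean
    new_var = decay * moving_var + (1 - decay) * batch_var
    inv_std = 1.0 / math.sqrt(batch_var + epsilon)
    # scale=False: no gamma, only the beta offset after normalization
    out = [beta + (v - batch_mean) * inv_std for v in values]
    if relu:
        # activation_fn=tf.nn.relu in conv() when norm=True
        out = [max(0.0, o) for o in out]
    return out, new_mean, new_var
```

With `relu=False` the outputs are only centered and rescaled (plus beta), which corresponds to what `conv(..., norm=False)` returns.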