ops.py source file

python

Project: hyperchamber  Author: 255BITS
import tensorflow as tf  # written against the TensorFlow 1.x API


def constrained_conv2d(input_, output_dim,
           k_h=6, k_w=6, d_h=2, d_w=2, stddev=0.02,
           name="conv2d"):
    # The stride must divide the kernel size evenly in each dimension;
    # this is intended to reduce convolution artifacts.
    assert k_h % d_h == 0
    assert k_w % d_w == 0
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))

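        # Pad the top and left edges; this is meant to reduce boundary artifacts.
        # NOTE: `padded` is never used afterwards. The convolution below runs
        # on the unpadded `input_` with SAME padding.
        padded = tf.pad(input_, [[0, 0],
            [k_h-1, 0],
            [k_w-1, 0],
            [0, 0]])
        conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')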

        biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
        conv = tf.nn.bias_add(conv, biases)

        return conv
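
The comments above say only that the kernel/stride constraint is intended to reduce convolution artifacts. One plausible reading, which is an assumption and not stated in the source, is that a stride dividing the kernel size makes the kernel windows overlap evenly across the input. The pure-Python sketch below counts, along one spatial dimension, how many windows touch each input position under TensorFlow's SAME padding rule. The helper names `same_conv_geometry` and `kernel_coverage` are hypothetical and are not part of hyperchamber.

```python
import math


def same_conv_geometry(in_size, kernel, stride):
    """Output size and (before, after) padding for one dimension under SAME."""
    out_size = math.ceil(in_size / stride)
    total_pad = max((out_size - 1) * stride + kernel - in_size, 0)
    pad_before = total_pad // 2
    return out_size, (pad_before, total_pad - pad_before)


def kernel_coverage(in_size, kernel, stride):
    """Count how many kernel windows touch each input position."""
    out_size, (pad_before, _) = same_conv_geometry(in_size, kernel, stride)
    counts = [0] * in_size
    for o in range(out_size):
        start = o * stride - pad_before  # window start in input coordinates
        for i in range(start, start + kernel):
            if 0 <= i < in_size:  # skip positions that fall in the padding
                counts[i] += 1
    return counts


# Kernel 6, stride 2 (the defaults above): even coverage away from the edges.
print(kernel_coverage(8, 6, 2))
# Kernel 5, stride 2 (rejected by the asserts): coverage alternates.
print(kernel_coverage(8, 5, 2))
```

With the default `k_h=6, d_h=2`, every interior position is touched by the same number of windows, while a kernel of 5 with stride 2 alternates between two and three. Whether that unevenness is the artifact the author had in mind is not confirmed by the source.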