ops.py source code

python

Project: speech-enhancement-WGAN  Author: jerrygood0703
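import tensorflow as tf  # written against the TensorFlow 1.x graph API


def conv2d(batch_input, out_channels, filter_shape, strides, name="conv"):
    """2-D convolution on an NCHW tensor, using reflection padding plus a VALID conv.

    batch_input:  4-D tensor laid out as [batch, channels, height, width].
    filter_shape: (kh, kw) kernel size.
    strides:      NCHW stride list [1, 1, sh, sw].
    Relies on pad_numbers(), a helper defined elsewhere in the project (not shown).
    """
    with tf.variable_scope(name):
        # NCHW layout: dims 1, 2, 3 are channels, height, width
        in_channels = batch_input.get_shape()[1]
        in_height = batch_input.get_shape()[2]
        in_width = batch_input.get_shape()[3]
        kh, kw = filter_shape
        _, _, sh, sw = strides
        # tf.nn.conv2d expects filters as [height, width, in_channels, out_channels],
        # even when data_format="NCHW"
        w = tf.get_variable(name="w",
                            shape=[kh, kw, in_channels, out_channels],
                            dtype=tf.float32,
                            initializer=tf.random_normal_initializer(0, 0.02))
        # b = tf.get_variable(name='b',
        #                     shape=[out_channels],
        #                     initializer=tf.constant_initializer(0.0))

        # [before, after] padding amounts for the height and width axes
        ph = pad_numbers(int(in_height), kh, sh)
        pw = pad_numbers(int(in_width), kw, sw)

        # Reflect-pad only H and W, then convolve with VALID so no zero padding is added
        padded_input = tf.pad(batch_input, [[0, 0], [0, 0], ph, pw], mode="REFLECT")
        # conv = tf.nn.bias_add(tf.nn.conv2d(padded_input, w, strides, padding="VALID", data_format="NCHW"), b, data_format="NCHW")
        conv = tf.nn.conv2d(padded_input, w, strides, padding="VALID", data_format="NCHW")
        return conv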
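The listing calls pad_numbers but does not include its body. The NumPy sketch below assumes it returns TensorFlow "SAME"-style [before, after] amounts, so that reflect padding followed by a VALID convolution yields ceil(size / stride) outputs per axis. Both pad_numbers here and conv2d_nchw_reflect are hypothetical stand-ins written for illustration, not the project's code. The sketch mirrors the layer's shapes: NCHW input, filters in TensorFlow's [kh, kw, in_channels, out_channels] layout, and no bias.

```python
import numpy as np


def pad_numbers(size, kernel, stride):
    """Hypothetical stand-in: [before, after] padding matching TF's "SAME" rule.

    With this padding, a VALID conv over the axis gives ceil(size / stride)
    outputs; any odd leftover goes to the end of the axis, as in TensorFlow.
    """
    out = -(-size // stride)  # ceil(size / stride)
    total = max((out - 1) * stride + kernel - size, 0)
    return [total // 2, total - total // 2]


def conv2d_nchw_reflect(x, w, strides):
    """Hypothetical NumPy reference for reflect-pad + VALID conv2d in NCHW layout.

    x: [N, C, H, W]; w: [kh, kw, C, O] (TensorFlow filter layout);
    strides: [1, 1, sh, sw]. Like the TF layer above, no bias is added.
    """
    n, c, h, wd = x.shape
    kh, kw, _, o = w.shape
    _, _, sh, sw = strides
    ph = pad_numbers(h, kh, sh)
    pw = pad_numbers(wd, kw, sw)
    # numpy "reflect" mirrors without repeating the edge, like tf.pad(mode="REFLECT")
    xp = np.pad(x, [(0, 0), (0, 0), tuple(ph), tuple(pw)], mode="reflect")
    out_h = (xp.shape[2] - kh) // sh + 1
    out_w = (xp.shape[3] - kw) // sw + 1
    y = np.zeros((n, o, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            patch = xp[:, :, i * sh:i * sh + kh, j * sw:j * sw + kw]  # [N, C, kh, kw]
            # contract over channels and kernel offsets -> [N, O]
            y[:, :, i, j] = np.einsum("nchw,hwco->no", patch, w)
    return y


# Usage: a batch of 2 single-channel 129x64 inputs, 16 filters of size 5x5, stride 2
x = np.random.default_rng(0).standard_normal((2, 1, 129, 64)).astype(np.float32)
w = np.random.default_rng(1).normal(0, 0.02, (5, 5, 1, 16)).astype(np.float32)
y = conv2d_nchw_reflect(x, w, [1, 1, 2, 2])
print(y.shape)  # (2, 16, 65, 32)
```

Padding with reflected values instead of zeros avoids artificial silence-like borders at the edges of each feature map. The trade-off is that TensorFlow's REFLECT mode requires each padding amount to be smaller than the corresponding input dimension.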