simple.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:tensorbayes 作者: RuiShu 项目源码 文件源码
def conv2d_transpose(x,
                     num_outputs,
                     kernel_size,
                     strides,
                     padding='SAME',
                     output_shape=None,
                     output_like=None,
                     activation=None,
                     bn=False,
                     post_bn=False,
                     phase=None,
                     scope=None,
                     reuse=None):
    """Transposed 2-D convolution plus bias, with optional batch norm/activation.

    The output shape is chosen, in order of priority, from:
      1. ``output_shape`` -- its batch entry is replaced by the runtime batch
         size of ``x``;
      2. the shape of ``output_like``;
      3. shape inference (only when ``padding == 'SAME'``): each spatial
         size of ``x`` multiplied by the corresponding stride. Spatial sizes
         unknown at graph-construction time are computed at runtime.

    Args:
        x: Input tensor. Its static shape is unpacked as
            (batch, height, width, channels), so it must be rank 4 with
            channels last; the last dimension sizes the kernel's input
            channels.
        num_outputs: Number of output channels.
        kernel_size: Int, or [height, width] of the filter.
        strides: Int, or [height, width] stride.
        padding: Padding string forwarded to ``tf.nn.conv2d_transpose``.
        output_shape: Optional list/tuple giving the full output shape;
            element 0 is ignored and replaced by the dynamic batch size.
        output_like: Optional tensor whose shape the output should take.
        activation: Optional callable applied between ``bn`` and ``post_bn``.
        bn: If true, apply ``batch_norm`` (scope 'bn') before ``activation``.
        post_bn: If true, apply ``batch_norm`` (scope 'post_bn') after
            ``activation``.
        phase: Forwarded to ``batch_norm`` (presumably the train/test flag;
            confirm against batch_norm's definition).
        scope: Variable scope name; defaults to 'conv2d'.
        reuse: Forwarded to ``tf.variable_scope``.

    Returns:
        The output tensor, with its static shape set via ``set_shape``.

    Raises:
        ValueError: If neither ``output_shape`` nor ``output_like`` is given
            and ``padding`` is not 'SAME'.
    """
    # Convert int to [int, int]
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size] * 2
    if isinstance(strides, int):
        strides = [strides] * 2

    # tf.nn.conv2d_transpose expects filters laid out as
    # [height, width, output_channels, input_channels].
    kernel_shape = list(kernel_size) + [num_outputs, x.get_shape().dims[-1]]
    strides = [1] + list(strides) + [1]

    # Get output shape both as a tensor (for the op) and as a static
    # shape (for set_shape).
    if output_shape:
        bs = tf.shape(x)[0]
        # list() so that tuples are accepted as well as lists.
        _output_shape = tf.stack([bs] + list(output_shape[1:]))
    elif output_like is not None:
        # Must not test ``if output_like:`` -- in graph mode Tensor.__bool__
        # raises TypeError, which made this branch unreachable.
        _output_shape = tf.shape(output_like)
        output_shape = output_like.get_shape()
    else:
        # Explicit raise instead of assert, which is stripped under -O.
        if padding != 'SAME':
            raise ValueError(
                "Shape inference only applicable with padding is SAME")
        bs, h, w, _ = x.get_shape().as_list()
        x_dyn_shape = tf.shape(x)
        # With SAME padding, each spatial output size is input size * stride.
        # Use static sizes when known; otherwise compute them at runtime.
        out_h = None if h is None else strides[1] * h
        out_w = None if w is None else strides[2] * w
        _output_shape = tf.stack([
            x_dyn_shape[0],
            strides[1] * x_dyn_shape[1] if out_h is None else out_h,
            strides[2] * x_dyn_shape[2] if out_w is None else out_w,
            num_outputs,
        ])
        output_shape = [bs, out_h, out_w, num_outputs]

    # Transposed conv operation
    with tf.variable_scope(scope, 'conv2d', reuse=reuse):
        kernel = tf.get_variable('weights', kernel_shape,
                                 initializer=variance_scaling_initializer())
        biases = tf.get_variable('biases', [num_outputs],
                                 initializer=tf.zeros_initializer)
        output = tf.nn.conv2d_transpose(x, kernel, _output_shape, strides,
                                        padding, name='conv2d_transpose')
        output += biases
        # conv2d_transpose loses static shape info when _output_shape is a
        # dynamic tensor; restore whatever is statically known.
        output.set_shape(output_shape)
        if bn:
            output = batch_norm(output, phase, scope='bn')
        if activation:
            output = activation(output)
        if post_bn:
            output = batch_norm(output, phase, scope='post_bn')

    return output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号