def conv2d_transpose(x,
                     num_outputs,
                     kernel_size,
                     strides,
                     padding='SAME',
                     output_shape=None,
                     output_like=None,
                     activation=None,
                     bn=False,
                     post_bn=False,
                     phase=None,
                     scope=None,
                     reuse=None):
    """Transposed 2-D convolution ("deconvolution") layer with bias.

    Assumes NHWC layout (channels last) — the input channel count is read
    from the last dimension of ``x`` and the stride list is built as
    ``[1, sh, sw, 1]``.

    Args:
        x: Input tensor; assumed 4-D NHWC — TODO confirm with callers.
        num_outputs: Number of output channels.
        kernel_size: int or pair ``(kh, kw)``.
        strides: int or pair ``(sh, sw)``.
        padding: ``'SAME'`` or ``'VALID'`` (forwarded to
            ``tf.nn.conv2d_transpose``).
        output_shape: Optional static output shape ``[bs, h, w, c]``; the
            batch entry is replaced by the dynamic batch size of ``x``.
        output_like: Optional tensor whose shape the output should match.
        activation: Optional activation callable.
        bn: Apply ``batch_norm`` before the activation.
        post_bn: Apply ``batch_norm`` after the activation.
        phase: Training-phase flag forwarded to ``batch_norm``.
        scope: Variable scope name (defaults to ``'conv2d'``).
        reuse: Variable reuse flag for the scope.

    Returns:
        The output tensor, with its static shape set via ``set_shape``.

    Raises:
        ValueError: If neither ``output_shape`` nor ``output_like`` is given
            and ``padding`` is not ``'SAME'`` (shape inference impossible).
    """
    # Normalize int arguments to 2-element lists.
    kernel_size = [kernel_size] * 2 if isinstance(kernel_size, int) else kernel_size
    strides = [strides] * 2 if isinstance(strides, int) else strides
    # tf.nn.conv2d_transpose expects the filter as
    # [kh, kw, output_channels, input_channels].
    kernel_size = list(kernel_size) + [num_outputs, x.get_shape().dims[-1]]
    strides = [1] + list(strides) + [1]
    # Resolve the output shape both dynamically (tensor, for the op) and
    # statically (for set_shape below).
    # NOTE: the `is not None` checks are required — truth-testing a tf.Tensor
    # (`if output_like:`) raises a TypeError in graph mode, so the original
    # truthiness test could never accept a Tensor here.
    if output_shape is not None:
        bs = tf.shape(x)[0]
        _output_shape = tf.stack([bs] + list(output_shape[1:]))
    elif output_like is not None:
        _output_shape = tf.shape(output_like)
        output_shape = output_like.get_shape()
    else:
        # assert would be stripped under `python -O`; raise explicitly.
        if padding != 'SAME':
            raise ValueError(
                "Shape inference only applicable when padding is 'SAME'")
        # Public API instead of the private Tensor._shape_as_list().
        # NOTE(review): h/w must be statically known here, or the stride
        # multiplication below fails — confirm callers guarantee that.
        bs, h, w, _ = x.get_shape().as_list()
        bs_tf = tf.shape(x)[0]
        _output_shape = tf.stack([bs_tf, strides[1] * h, strides[2] * w, num_outputs])
        output_shape = [bs, strides[1] * h, strides[2] * w, num_outputs]
    # Transposed conv + bias, then optional (post-)batch-norm and activation.
    with tf.variable_scope(scope, 'conv2d', reuse=reuse):
        kernel = tf.get_variable('weights', kernel_size,
                                 initializer=variance_scaling_initializer())
        biases = tf.get_variable('biases', [num_outputs],
                                 initializer=tf.zeros_initializer)
        output = tf.nn.conv2d_transpose(x, kernel, _output_shape, strides,
                                        padding, name='conv2d_transpose')
        output += biases
        output.set_shape(output_shape)
        if bn:
            output = batch_norm(output, phase, scope='bn')
        if activation:
            output = activation(output)
        if post_bn:
            output = batch_norm(output, phase, scope='post_bn')
        return output
# NOTE: stray web-page text removed here ("评论列表" / "文章目录", i.e.
# "comment list" / "article table of contents" — blog-scrape residue that
# made the file a syntax error).