def deconv_stride2_multistep(x,
                             nbr_steps,
                             output_filters,
                             name=None,
                             reuse=None):
  """Use a deconvolution to upsample x by 2**`nbr_steps`.

  Each step is a 1x1 convolution (with ReLU) that multiplies the channel
  count, followed by a rearrangement of the extra channels into space:
    * if axis 2 has size 1 (a 1-D signal stored as
      `[batch, spatial, 1, depth]`), only axis 1 is doubled;
    * otherwise axes 1 and 2 are both doubled via `tf.depth_to_space`.

  Args:
    x: a `Tensor` with shape `[batch, spatial, 1, depth]` (1-D case) or
      `[batch, spatial_1, spatial_2, depth]` (2-D case). Axis 2 must exist,
      because its size selects the 1-D vs 2-D path. A 3-D
      `[batch, spatial, depth]` tensor is NOT handled as 1-D: axis 2 would be
      its depth axis.
    nbr_steps: an int specifying the number of doubling upsample rounds to
      apply. With `nbr_steps == 0`, `x` is returned unchanged.
    output_filters: an int specifying the filter count for the deconvolutions.
    name: a string, the variable scope name (defaults to
      "deconv_stride2_multistep").
    reuse: a boolean, passed through to `tf.variable_scope`.

  Returns:
    a `Tensor` with shape `[batch, spatial * (2**nbr_steps), 1, output_filters]`
    or `[batch, spatial_1 * (2**nbr_steps), spatial_2 * (2**nbr_steps),
    output_filters]`.
  """
  with tf.variable_scope(
      name, default_name="deconv_stride2_multistep", values=[x], reuse=reuse):

    def deconv1d(cur, i):
      """Doubles axis 1 of a `[batch, length, 1, depth]` tensor."""
      cur_shape = shape_list(cur)
      # Produce 2 * output_filters channels, then fold the extra channels into
      # the length axis. This is valid only because axis 2 has size 1 here.
      thicker = conv(
          cur,
          output_filters * 2, (1, 1),
          padding="SAME",
          activation=tf.nn.relu,
          name="deconv1d" + str(i))
      return tf.reshape(thicker,
                        [cur_shape[0], cur_shape[1] * 2, 1, output_filters])

    def deconv2d(cur, i):
      """Doubles axes 1 and 2 of a `[batch, height, width, depth]` tensor."""
      # Produce 4 * output_filters channels; depth_to_space with block size 2
      # moves them into a 2x2 spatial block.
      thicker = conv(
          cur,
          output_filters * 4, (1, 1),
          padding="SAME",
          activation=tf.nn.relu,
          name="deconv2d" + str(i))
      return tf.depth_to_space(thicker, 2)

    cur = x
    for i in range(nbr_steps):
      # Statically known size-1 axis 2: take the 1-D path without consulting
      # shape_list at all.
      if cur.get_shape()[2] == 1:
        cur = deconv1d(cur, i)
        continue
      cur_dim = shape_list(cur)[2]
      if isinstance(cur_dim, int):
        cur = deconv1d(cur, i) if cur_dim == 1 else deconv2d(cur, i)
      else:
        # Axis-2 size is only known at run time, so choose the path
        # dynamically. tf.cond builds both branches, which means variables
        # for both "deconv1d<i>" and "deconv2d<i>" are created. The default
        # arguments pin the current step index and tensor into each lambda.
        cur = tf.cond(
            tf.equal(cur_dim, 1),
            lambda idx=i, t=cur: deconv1d(t, idx),
            lambda idx=i, t=cur: deconv2d(t, idx))
    return cur
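# Stray web-page labels ("Comment list", "Table of contents") left over from
# the page this code was copied from; not part of the program.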