def decompress_step(source, hparams, first_relu, is_2d, name):
  """Upsample `source` by a factor of 2 along its spatial dimension(s).

  A single convolution widens the channel dimension. The extra channels are
  then rearranged into additional spatial positions: a 2x2 block when `is_2d`
  is True, or two consecutive positions along axis 1 otherwise.

  Args:
    source: Input tensor. Assumed to be 4-D: `[batch, length, 1, channels]`
      when `is_2d` is False, and `[batch, height, width, channels]` when it
      is True. TODO: confirm against callers.
    hparams: Hyperparameters. Only `hparams.hidden_size` is read here.
    first_relu: Forwarded unchanged to `common_layers.conv_block`.
    is_2d: If True, upsample height and width with `tf.depth_to_space`.
      Otherwise, double axis 1 with a reshape.
    name: Name of the `tf.variable_scope` that wraps the ops created here.

  Returns:
    If `is_2d` is True, `tf.depth_to_space(conv_output, 2)`. Its height and
    width are doubled, and its depth is the conv output depth divided by 4.
    Otherwise, a tensor reshaped to
    `[batch, 2 * length, 1, hparams.hidden_size]`.
  """
  with tf.variable_scope(name):
    shape = common_layers.shape_list(source)
    # Widen by the number of output positions each input position expands
    # into: 2x2 = 4 in the 2-D case, 2 in the 1-D case. Each output position
    # then ends up with hidden_size channels.
    multiplier = 4 if is_2d else 2
    # The kernel is (1, 1) in both the 1-D and the 2-D case.
    kernel = (1, 1)
    thicker = common_layers.conv_block(
        source,
        hparams.hidden_size * multiplier,
        [((1, 1), kernel)],  # presumably (dilation_rate, kernel_size) pairs
        first_relu=first_relu,
        name="decompress_conv")
    if is_2d:
      # Move each group of channels into a 2x2 spatial block. depth_to_space
      # divides depth by 4 and doubles height and width.
      return tf.depth_to_space(thicker, 2)
    # Split the channels at each position into two consecutive positions
    # along axis 1. Element counts only match if axis 2 of `source` is 1 and
    # the conv preserves the spatial extent. Assumed here; verify.
    return tf.reshape(
        thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])
评论列表
文章目录