def encode(self, inputs):
    """Encode a batch of images into attention features and a pooled state.

    The images are resized to the resolution configured in ``self.params``
    (``resize_height`` x ``resize_width``), cast to float and passed through
    ``inception_v3_base``. The resulting feature map is exposed in two forms:
    flattened over its spatial positions (for attention) and average-pooled
    over its full spatial extent (as the final state).

    Args:
        inputs: Image tensor accepted by ``tf.image.resize_images``;
            presumably laid out as ``[B, H, W, C]`` -- confirm against the
            caller.

    Returns:
        An ``EncoderOutput`` where ``outputs`` and ``attention_values`` are
        both the feature map reshaped to ``[B, H*W, C]``, ``final_state`` is
        the spatially average-pooled feature map flattened to ``[B, C]``, and
        ``attention_values_length`` is the scalar number of spatial
        positions (``H*W``).
    """
    resized = tf.image.resize_images(
        images=inputs,
        size=[self.params["resize_height"], self.params["resize_width"]],
        method=tf.image.ResizeMethod.BILINEAR)

    features, _ = inception_v3_base(tf.to_float(resized))
    feature_shape = features.get_shape()  #pylint: disable=E1101
    static_dims = feature_shape.as_list()

    # The static batch dimension is None when the batch size is only known at
    # run time; tf.reshape cannot take None, so fall back to the dynamic size.
    batch_size = static_dims[0]
    if batch_size is None:
        batch_size = tf.shape(features)[0]

    # Take attention over output elements in the width and height dimensions.
    # Shape: [B, W*H, C]
    features_flat = tf.reshape(features, [batch_size, -1, static_dims[-1]])

    # Final state is the pooled output: average over the whole spatial extent,
    # then drop the singleton spatial dimensions.
    # Shape: [B, C]
    pooled = tf.contrib.slim.avg_pool2d(
        features, feature_shape[1:3], padding="VALID", scope="pool")
    # Flatten the *pooled* tensor; flattening `features` here would discard
    # the pooling result entirely.
    final_state = tf.contrib.slim.flatten(pooled, scope="flatten")

    return EncoderOutput(
        outputs=features_flat,
        final_state=final_state,
        attention_values=features_flat,
        attention_values_length=tf.shape(features_flat)[1])
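# Web-page navigation residue, not part of the code: "评论列表" (comment list)
# Web-page navigation residue, not part of the code: "文章目录" (article table of contents)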