image_encoder.py source code

python

Project: seq2seq  Author: google
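# Module-level imports this method relies on (TensorFlow 1.x with tf.contrib).
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base

from seq2seq.encoders.encoder import EncoderOutput


def encode(self, inputs):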
    inputs = tf.image.resize_images(
        images=inputs,
        size=[self.params["resize_height"], self.params["resize_width"]],
        method=tf.image.ResizeMethod.BILINEAR)

    outputs, _ = inception_v3_base(tf.to_float(inputs))
    output_shape = outputs.get_shape()  #pylint: disable=E1101
    shape_list = output_shape.as_list()

    # Take attention over output elements in the height and width dimensions:
    # Shape: [B, H*W, C]
    outputs_flat = tf.reshape(outputs, [shape_list[0], -1, shape_list[-1]])

    # Final state is the pooled output: average over the full spatial extent,
    # then flatten the pooled tensor (not the raw feature map).
    # Shape: [B, C]
    final_state = tf.contrib.slim.avg_pool2d(
        outputs, output_shape[1:3], padding="VALID", scope="pool")
    final_state = tf.contrib.slim.flatten(final_state, scope="flatten")

    return EncoderOutput(
        outputs=outputs_flat,
        final_state=final_state,
        attention_values=outputs_flat,
        attention_values_length=tf.shape(outputs_flat)[1])
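
The reshaping in `encode` can be checked without TensorFlow. The sketch below is a NumPy stand-in, not the project's code: `flatten_and_pool` is a hypothetical helper, the 8x8x2048 feature-map size is only an illustrative assumption, and it assumes the final state is meant to be the average-pooled feature map, as the comment in the method says.

```python
import numpy as np


def flatten_and_pool(feature_map):
    """Mimic the two reshaping steps of encode() on a [B, H, W, C] array."""
    batch, height, width, channels = feature_map.shape
    # Attention values: one C-dimensional vector per spatial position.
    outputs_flat = feature_map.reshape(batch, height * width, channels)
    # Final state: average over both spatial axes, giving [B, C].
    final_state = feature_map.mean(axis=(1, 2))
    return outputs_flat, final_state


# Illustrative input only; the real feature map comes from inception_v3_base.
features = np.random.rand(2, 8, 8, 2048).astype(np.float32)
outputs_flat, final_state = flatten_and_pool(features)
print(outputs_flat.shape)  # (2, 64, 2048)
print(final_state.shape)   # (2, 2048)
```

Each of the 64 rows in `outputs_flat` is one spatial position an attention decoder can attend to, while `final_state` summarizes the whole image in a single vector per example.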