decoder_test.py source code

python

Project: seq2seq  Author: google
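import numpy as np
import tensorflow as tf  # TF 1.x API: tf.contrib does not exist in TF 2

from seq2seq.decoders.attention import AttentionLayerDot
from seq2seq.decoders.attention_decoder import AttentionDecoder


def create_decoder(self, helper, mode):
    # Dot-product attention layer; its mode is fixed to TRAIN in this test.
    attention_fn = AttentionLayerDot(
        params={"num_units": self.attention_dim},
        mode=tf.contrib.learn.ModeKeys.TRAIN)
    # Random stand-ins for encoder outputs,
    # shape [batch_size, input_seq_len, 32].
    attention_values = tf.convert_to_tensor(
        np.random.randn(self.batch_size, self.input_seq_len, 32),
        dtype=tf.float32)
    attention_keys = tf.convert_to_tensor(
        np.random.randn(self.batch_size, self.input_seq_len, 32),
        dtype=tf.float32)
    # Start from the default hyperparameters and cap the decode length.
    params = AttentionDecoder.default_params()
    params["max_decode_length"] = self.max_decode_length
    # attention_values_length = [1, 2, ..., batch_size]: example i may
    # attend only to its first i + 1 source positions.
    return AttentionDecoder(
        params=params,
        mode=mode,
        vocab_size=self.vocab_size,
        attention_keys=attention_keys,
        attention_values=attention_values,
        attention_values_length=np.arange(self.batch_size) + 1,
        attention_fn=attention_fn)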
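To show what the attention inputs above are used for, here is a minimal NumPy sketch of dot-product attention with length masking. It is an illustration of the general technique, not the seq2seq library's code: the helper names `sequence_mask` and `dot_attention` are hypothetical, and any learned projections of keys and query that the real `AttentionLayerDot` may apply are omitted, so query and keys must already share the same depth. The lengths `np.arange(batch_size) + 1` mirror the `attention_values_length` argument in the test.

```python
import numpy as np


def sequence_mask(lengths, maxlen):
    """Hypothetical helper: boolean mask [batch, maxlen], True where position < length."""
    return np.arange(maxlen)[None, :] < np.asarray(lengths)[:, None]


def dot_attention(query, keys, values, values_length):
    """Sketch of dot-product attention over a padded batch (not the library API).

    query:  [batch, num_units]
    keys:   [batch, seq_len, num_units]
    values: [batch, seq_len, value_dim]
    values_length: [batch] count of valid (unpadded) positions, each >= 1
    Returns (weights [batch, seq_len], context [batch, value_dim]).
    """
    # Raw scores: dot product of the query with the key at every position.
    scores = np.einsum("bd,btd->bt", query, keys)
    # Padded positions get -inf so the softmax assigns them zero weight.
    mask = sequence_mask(values_length, keys.shape[1])
    scores = np.where(mask, scores, -np.inf)
    # Numerically stable softmax along the time axis.
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    # Context vector: attention-weighted sum of the values.
    context = np.einsum("bt,btd->bd", weights, values)
    return weights, context


# Usage with shapes like the test's random inputs (values are illustrative).
batch_size, input_seq_len, num_units = 4, 5, 32
rng = np.random.default_rng(0)
keys = rng.standard_normal((batch_size, input_seq_len, num_units))
values = rng.standard_normal((batch_size, input_seq_len, num_units))
query = rng.standard_normal((batch_size, num_units))
lengths = np.arange(batch_size) + 1  # 1, 2, 3, 4
weights, context = dot_attention(query, keys, values, lengths)
print(weights.round(3))
```

With lengths 1 through 4, the first example can only attend to position 0, so all of its weight lands there, while later examples spread weight over progressively more positions. This is why the test uses distinct lengths per batch entry: it exercises the masking path.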