seq2seq_test.py source code

Project: DeepLearning_VirtualReality_BigData_Project  Author: rashmitripathi

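```python
# TensorFlow 1.0-era internal import paths (tf.contrib); removed in TF 2.x.
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class Seq2SeqTest(test.TestCase):  # enclosing test class, shown for context

  def testDynamicAttentionDecoder1(self):
    with self.test_session() as sess:
      # Every variable created under "root" starts at 0.5, so results are deterministic.
      with variable_scope.variable_scope(
          "root", initializer=init_ops.constant_initializer(0.5)):
        cell = core_rnn_cell_impl.GRUCell(2)
        # Encoder input: [batch=2, time=2, depth=2] (batch-major, the dynamic_rnn default).
        inp = constant_op.constant(0.5, shape=[2, 2, 2])
        enc_outputs, enc_state = rnn.dynamic_rnn(
            cell, inp, dtype=dtypes.float32)
        # Attend over every encoder output: [batch, attn_length, attn_size].
        attn_states = enc_outputs
        # Three decoder steps, each of shape [batch, input_size].
        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
        dec, mem = seq2seq_lib.attention_decoder(
            dec_inp, enc_state, attn_states, cell, output_size=4)
        sess.run([variables.global_variables_initializer()])
        # One output per decoder step, each projected to output_size=4.
        res = sess.run(dec)
        self.assertEqual(3, len(res))
        self.assertEqual((2, 4), res[0].shape)

        # The final state has the GRU state size (2) for each batch entry.
        res = sess.run([mem])
        self.assertEqual((2, 2), res[0].shape)
```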
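The test only checks shapes: three decoder outputs of shape (2, 4) and a final state of shape (2, 2). The NumPy sketch below is a simplified, hypothetical illustration of where those shapes come from. It loosely follows the structure of the legacy `attention_decoder` as we understand it (zero context before the first step, attention queried with the new cell state, each output projected from the cell output plus the context), but it is not TensorFlow's code: a plain tanh RNN cell stands in for `GRUCell`, all weights are a fixed constant with zero biases, and the helper names `linear`, `additive_attention` and `attention_decode` are our own.

```python
import numpy as np


def linear(inputs, w, b=0.0):
    """Concatenate inputs on the last axis, then apply an affine map."""
    return np.concatenate(inputs, axis=-1) @ w + b


def additive_attention(attn_states, query, w_keys, w_query, v):
    """Additive attention: context = sum_j softmax_j(v . tanh(W1 h_j + W2 q)) * h_j.

    attn_states: [batch, attn_length, attn_size]; query: [batch, state_size].
    Returns the context vector with shape [batch, attn_size].
    """
    keys = attn_states @ w_keys                          # [batch, attn_length, attn_size]
    q = (query @ w_query)[:, None, :]                    # [batch, 1, attn_size]
    scores = np.tanh(keys + q) @ v                       # [batch, attn_length]
    scores = scores - scores.max(axis=1, keepdims=True)  # stabilise the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return (weights[:, :, None] * attn_states).sum(axis=1)


def attention_decode(dec_inp, state, attn_states, output_size, init=0.5):
    """Hypothetical simplified attention-decoder loop (not TensorFlow's implementation).

    A plain tanh RNN cell stands in for the GRU, every weight matrix is filled
    with the constant `init`, and all biases are zero.
    """
    batch, _, attn_size = attn_states.shape
    input_size, state_size = dec_inp[0].shape[1], state.shape[1]

    def const(*shape):
        return np.full(shape, init)

    w_in = const(input_size + attn_size, input_size)
    w_cell = const(input_size + state_size, state_size)
    w_keys, w_query, v = const(attn_size, attn_size), const(state_size, attn_size), const(attn_size)
    w_out = const(state_size + attn_size, output_size)

    attns = np.zeros((batch, attn_size))  # no attention read before the first step
    outputs = []
    for inp in dec_inp:
        x = linear([inp, attns], w_in)                # mix the input with the last context
        state = np.tanh(linear([x, state], w_cell))   # stand-in RNN cell; its output is its state
        attns = additive_attention(attn_states, state, w_keys, w_query, v)
        outputs.append(linear([state, attns], w_out))  # project to output_size
    return outputs, state


# Shapes mirror the test: batch 2, encoder length 2, depth 2, three decoder steps.
enc_outputs = np.random.default_rng(0).normal(size=(2, 2, 2))
enc_state = enc_outputs[:, -1, :]  # for a GRU without sequence_length, final state == last output
dec_inp = [np.full((2, 2), 0.4)] * 3
outputs, state = attention_decode(dec_inp, enc_state, enc_outputs, output_size=4)
print(len(outputs), outputs[0].shape, state.shape)  # 3 (2, 4) (2, 2)
```

Because each step feeds the previous context back into the input, the decoder needs the full encoder outputs (`attn_states`), not just the final encoder state; the output width is set only by the final projection, which is why `output_size=4` can differ from the cell size of 2.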