def testRNNDecoder(self):
    """rnn_decoder yields one projected output per step plus a final state.

    A 2-unit GRU encoder runs over two constant [2, 2] inputs (value 0.5).
    Its final state seeds a decoder whose 2-unit GRU cell is wrapped in an
    OutputProjectionWrapper of size 4. Three constant decoder inputs
    (value 0.4) are fed in. The test expects three outputs, the first shaped
    (2, 4), and a final decoder state shaped (2, 2).

    Everything is built under a "root" variable scope whose default
    initializer is the constant 0.5.
    """
    with self.test_session() as sess:
        root_init = init_ops.constant_initializer(0.5)
        with variable_scope.variable_scope("root", initializer=root_init):
            # Encoder: a 2-unit GRU over two identical 0.5-filled steps.
            encoder_steps = [constant_op.constant(0.5, shape=[2, 2])] * 2
            encoder_cell = core_rnn_cell_impl.GRUCell(2)
            _, encoder_state = core_rnn.static_rnn(
                encoder_cell, encoder_steps, dtype=dtypes.float32)

            # Decoder: three 0.4-filled steps through a GRU whose output is
            # projected to 4 units.
            decoder_steps = [constant_op.constant(0.4, shape=[2, 2])] * 3
            gru = core_rnn_cell_impl.GRUCell(2)
            decoder_cell = core_rnn_cell_impl.OutputProjectionWrapper(gru, 4)
            outputs, final_state = seq2seq_lib.rnn_decoder(
                decoder_steps, encoder_state, decoder_cell)

            # The initializer op must be created after rnn_decoder so that it
            # covers the variables the decoder just created.
            sess.run(variables.global_variables_initializer())

            output_values = sess.run(outputs)
            self.assertEqual(3, len(output_values))
            self.assertEqual((2, 4), output_values[0].shape)

            (state_value,) = sess.run([final_state])
            self.assertEqual((2, 2), state_value.shape)
seq2seq_test.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录