def test_qrnn_with_previous(self):
    """Check the shape of the QRNN forward output when ``conv_size=2``.

    Builds a QRNN in a fresh ``tf.Graph``, feeds it a batch produced by
    ``self.create_test_data`` and asserts that the evaluated forward output
    has shape ``(batch_size, size)``.

    NOTE(review): "with_previous" in the name presumably refers to
    ``conv_size=2`` letting each step also see the previous timestep's
    input -- confirm against the QRNN implementation.
    """
    batch_size = 100
    sentence_length = 5
    word_size = 10
    size = 5
    # Fed into the placeholder X below, so it must be compatible with shape
    # [batch_size, sentence_length, word_size].
    data = self.create_test_data(batch_size, sentence_length, word_size)

    with tf.Graph().as_default() as q_with_previous:
        qrnn = QRNN(in_size=word_size, size=size, conv_size=2)
        X = tf.placeholder(tf.float32, [batch_size, sentence_length, word_size])
        forward_graph = qrnn.forward(X)
        # Created after the QRNN and its forward op are built, so it covers
        # every global variable they added to this graph.
        init = tf.global_variables_initializer()

        # Bind the session to the graph the ops live in explicitly rather
        # than relying on whichever graph happens to be the default.
        with tf.Session(graph=q_with_previous) as sess:
            sess.run(init)
            hidden = sess.run(forward_graph, feed_dict={X: data})
            self.assertEqual((batch_size, size), hidden.shape)
test_tf_qrnn_forward.py 文件源码
python
阅读 36
收藏 0
点赞 0
评论 0
评论列表
文章目录