test_tf_qrnn_forward.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:tensorflow_qrnn 作者: icoxfog417 项目源码 文件源码
def test_qrnn_convolution(self):
        batch_size = 100
        sentence_length = 5
        word_size = 10
        size = 5
        data = self.create_test_data(batch_size, sentence_length, word_size)

        with tf.Graph().as_default() as q_conv:
            qrnn = QRNN(in_size=word_size, size=size, conv_size=3)
            X = tf.placeholder(tf.float32, [batch_size, sentence_length, word_size])
            forward_graph = qrnn.forward(X)

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                hidden = sess.run(forward_graph, feed_dict={X: data})
                self.assertEqual((batch_size, size), hidden.shape)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号