test_lstm.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:document-qa 作者: allenai 项目源码 文件源码
def test_forget_bias(self):
    """
    Make sure the forget bias is only being applied to the forget gate.

    Builds the four gate tensors with ``_compute_gates``, passing two
    ``init_ops.zeros_initializer()`` instances and the positional value ``1``.
    The ``1`` is presumably the forget bias and the initializers the
    weight/bias initializers; TODO confirm against ``_compute_gates``'s
    signature.

    The test feeds an all-zeros input and an all-ones hidden state. It then
    checks that the forget gate ``f`` evaluates to all ones while the ``i``,
    ``j`` and ``o`` gates evaluate to all zeros.
    """
    batches = 1
    num_units = 5
    num_inputs = 5

    hidden_size = (batches, num_units)
    input_size = (batches, num_inputs)

    # Build in a private graph so the "test_bias" variable scope cannot
    # collide with variables already created in the global default graph
    # (e.g. if the test is run twice in one process).
    with tf.Graph().as_default():
        inputs = tf.placeholder(dtype='float32', shape=input_size)
        h = tf.placeholder(dtype='float32', shape=hidden_size)
        with tf.variable_scope("test_bias"):
            i_t, j_t, f_t, o_t = _compute_gates(
                inputs, h, 4 * num_units, 1,
                init_ops.zeros_initializer(), init_ops.zeros_initializer())
        gates = [i_t, j_t, f_t, o_t]

        # Use the session as a context manager so it is always closed,
        # even if graph construction or evaluation raises.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            i, j, f, o = sess.run(
                gates,
                feed_dict={inputs: np.zeros(input_size),
                           h: np.ones(hidden_size)})

    # Make sure the bias is ONLY getting applied to the forget gate.
    # atol=1e-8 matches np.allclose's default tolerance; assert_allclose
    # reports the mismatching elements instead of a bare "False is not true".
    np.testing.assert_allclose(f, np.ones(f.shape), rtol=0, atol=1e-8)
    for gate in (i, j, o):
        np.testing.assert_allclose(gate, np.zeros(gate.shape), rtol=0, atol=1e-8)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号