test_thin_stack.py source code

python

Project: thinstack-rl    Author: hans
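import numpy as np
import tensorflow as tf


def test_basic_ff(self):
    # Method of the thin-stack test case; self._make_stack, self.stack and
    # self.seq_length are defined elsewhere in that class (not shown).
    self._make_stack(seq_length=5)

    # Token-id buffer, transposed to shape (buffer_length, batch_size) = (3, 2).
    X = np.array([
        [3, 1,  2],
        [3, 2,  4]
    ], dtype=np.int32).T

    # One transition sequence per example (rows) and timestep (columns);
    # the expected values below imply 0 = shift and 1 = reduce.
    transitions = np.array([
        [0, 0, 0, 1, 1],
        [0, 0, 1, 0, 1]
    ], dtype=np.float32)

    num_transitions = np.array([4, 4], dtype=np.int32)

    # Expected stack contents: seq_length * batch_size = 10 rows, with the
    # two examples interleaved at each timestep.
    expected = np.array([[ 3.,  3.,  3.],
                         [ 3.,  3.,  3.],
                         [ 1.,  1.,  1.],
                         [ 2.,  2.,  2.],
                         [ 2.,  2.,  2.],
                         [ 5.,  5.,  5.],
                         [ 3.,  3.,  3.],
                         [ 4.,  4.,  4.],
                         [ 6.,  6.,  6.],
                         [ 9.,  9.,  9.]])

    # Run twice to make sure state left over from the first run is
    # properly erased by ts.reset().
    with self.test_session() as s:
        # tf.variables_initializer replaces the deprecated tf.initialize_variables.
        s.run(tf.variables_initializer(tf.trainable_variables()))
        ts = self.stack

        # Feed one transitions placeholder per timestep (column t of the matrix).
        feed = {ts.transitions[t]: transitions[:, t]
                for t in range(self.seq_length)}
        feed[ts.buff] = X
        feed[ts.num_transitions] = num_transitions

        for _ in range(2):
            ts.reset(s)

            ret = s.run(ts.stack, feed)
            np.testing.assert_almost_equal(ret, expected)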
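The expected array can be checked by hand. The following is a minimal pure-NumPy sketch, not the thinstack-rl implementation, that replays the same transitions under assumptions read off the expected values: transition 0 is a shift and 1 is a reduce, a token id embeds as a vector filled with that id, a reduce sums the top two stack entries, and the flat stack is time-major, so row `t * batch_size + b` holds what example `b` pushed at step `t`. The helper name `thin_stack_ff` is hypothetical.

```python
import numpy as np

SHIFT, REDUCE = 0, 1  # assumed encoding, inferred from the test's expected values


def thin_stack_ff(buff_ids, transitions, embed_dim=3):
    """Hypothetical NumPy reference for a thin-stack feedforward pass.

    buff_ids:    int array of shape (buffer_length, batch_size)
    transitions: array of shape (batch_size, seq_length), 0 = shift, 1 = reduce
    Returns a flat stack of shape (seq_length * batch_size, embed_dim).
    """
    batch_size = buff_ids.shape[1]
    seq_length = transitions.shape[1]
    stack = np.zeros((seq_length * batch_size, embed_dim), dtype=np.float32)
    buf_ptr = np.zeros(batch_size, dtype=np.int64)  # next unread buffer slot per example
    pointers = [[] for _ in range(batch_size)]      # per-example stack of row indices

    for t in range(seq_length):
        for b in range(batch_size):
            row = t * batch_size + b  # time-major layout (assumed)
            if transitions[b, t] == SHIFT:
                # Assumed embedding: the token id repeated embed_dim times.
                stack[row] = buff_ids[buf_ptr[b], b]
                buf_ptr[b] += 1
            elif transitions[b, t] == REDUCE:
                # Assumed composition: elementwise sum of the top two entries.
                right = pointers[b].pop()
                left = pointers[b].pop()
                stack[row] = stack[left] + stack[right]
            else:
                raise ValueError("unknown transition %r" % transitions[b, t])
            pointers[b].append(row)  # the new value becomes the stack top
    return stack


X = np.array([[3, 1, 2], [3, 2, 4]], dtype=np.int32).T
transitions = np.array([[0, 0, 0, 1, 1], [0, 0, 1, 0, 1]], dtype=np.float32)
out = thin_stack_ff(X, transitions)
print(out[:, 0])  # [3. 3. 1. 2. 2. 5. 3. 4. 6. 9.]
```

This mirrors the thin-stack layout: every pushed value lives in one flat array, and each example keeps only a list of row indices, so a reduce reads two existing rows and writes one new row instead of copying vectors around. For example 0, the shifts push 3, 1, 2, the first reduce writes 1 + 2 = 3, and the second writes 3 + 3 = 6, matching rows 0, 2, 4, 6 and 8 of `expected`.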