test_nnbase.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:DiscourseSenser 作者: WladimirSidorenko 项目源码 文件源码
def test_predict_1(self):
        x = np.array([[0, 1, 2, 3, 4], [[0, 1, 2, 3, 4]]])
        y = np.array([0, 0, 0, 0])

        def _predict_func_mock(*args, **kwargs):
            return np.array([[0, 0, 1, 0],
                             [0, 0, 0, 1],
                             [0, 0, 1, 0],
                             [1, 0, 0, 0],
                             ])

        with patch.multiple(self.nnbs,
                            _predict_func=_predict_func_mock,
                            get_test_w_emb_i=None,
                            _init_wemb_funcs=MagicMock(),
                            _rel2x=MagicMock(return_value=x)):
            ret = np.array([[0] * 4])
            self.nnbs.predict(None, (None, None), ret, 0)
            assert np.allclose(ret[0], y)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号