test_parse_model.py source code

python

Project: lang2program, author: kelvinguu
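import numpy as np
import tensorflow as tf  # TF1-style API: tf.get_default_session()
from numpy.testing import assert_array_almost_equal

# guarantee_initialized_variables and RLongPrimitiveEmbeddings are lang2program
# helpers; their import paths are not shown in this excerpt.


# Method of a test class in test_parse_model.py (enclosing class not shown).
def test(self, model, cases):
    sess = tf.get_default_session()
    guarantee_initialized_variables(sess)
    embeds = model.compute(model.embeds, cases)
    primitive_embeddings = RLongPrimitiveEmbeddings(6)  # size 6, matching the np.zeros(6) segments below

    # compute object embedding after applying projection
    object_projection_layer = model._object_projection_layer
    W, b = object_projection_layer.get_weights()  # shapes [10, 6] and [6]
    object_embed = np.ones(10).dot(W) + b  # all-ones 10-dim input -> shape [6]

    # each expected embedding concatenates three 6-dim segments
    assert_array_almost_equal(
        embeds[0],
        np.concatenate((np.zeros(6), primitive_embeddings['r'], primitive_embeddings[-1])),
    )

    assert_array_almost_equal(
        embeds[1],
        np.concatenate((np.zeros(6), np.zeros(6), primitive_embeddings['X1/1'])),
    )

    assert_array_almost_equal(
        embeds[2],
        np.concatenate((primitive_embeddings['b'], object_embed, object_embed)),
    )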
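The expected values in this test come from simple array arithmetic. The object embedding is an affine projection of an all-ones 10-dimensional vector. Each expected case embedding is three 6-dimensional segments concatenated together. The sketch below reproduces that arithmetic with NumPy only, under stated assumptions, and is not lang2program code. Random weights stand in for the trained projection layer's `get_weights()` output. `toy_primitive_embeddings` is a hypothetical dict that stands in for `RLongPrimitiveEmbeddings(6)`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins: random weights in place of the trained projection
# layer's get_weights() output, using the shapes noted in the test.
W = rng.normal(size=(10, 6))
b = rng.normal(size=6)

# Affine projection of an all-ones 10-dim vector, as in the test.
object_embed = np.ones(10).dot(W) + b

# Multiplying by a vector of ones simply sums the rows of W.
assert np.allclose(object_embed, W.sum(axis=0) + b)

# Hypothetical 6-dim primitive lookup standing in for RLongPrimitiveEmbeddings(6).
toy_primitive_embeddings = {
    'r': rng.normal(size=6),
    'b': rng.normal(size=6),
    -1: rng.normal(size=6),
    'X1/1': rng.normal(size=6),
}

# Expected embeddings built the same way as in the test: three 6-dim segments each.
expected = [
    np.concatenate((np.zeros(6), toy_primitive_embeddings['r'], toy_primitive_embeddings[-1])),
    np.concatenate((np.zeros(6), np.zeros(6), toy_primitive_embeddings['X1/1'])),
    np.concatenate((toy_primitive_embeddings['b'], object_embed, object_embed)),
]

for e in expected:
    print(e.shape)  # (18,)
```

In the real test, `model.compute` supplies the actual embeddings. These are compared against arrays built like `expected` using `assert_array_almost_equal`.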