def test_combiner_d2_4():
    """ 1 point(s) """
    # Check MLPCombinerNetwork(6) on a fixed pair of seeded random 1x6 inputs
    # against reference outputs.
    #
    # The seed must come first. The reference values depend on the exact RNG
    # draw order: network construction (presumably parameter init; confirm),
    # then the head input, then the modifier input. Do not reorder these.
    torch.manual_seed(1)
    net = MLPCombinerNetwork(6)
    head = ag.Variable(torch.randn(1, 6))
    modifier = ag.Variable(torch.randn(1, 6))

    expected = [-0.4897, 0.4484, -0.0591, 0.1778, 0.4223, -0.0940]
    # Flatten the combined output to a plain Python list so it can be
    # compared element-wise with the reference values.
    actual = net(head, modifier).view(-1).data.tolist()

    # check_tensor_correctness receives a list of (actual, expected) pairs.
    check_tensor_correctness([(actual, expected)])
# ===-------------------------------------------------------------------------------------------===
# Section 3 tests
# ===-------------------------------------------------------------------------------------------===
test_parser.py 文件源码
python
阅读 30
收藏 0
点赞 0
评论 0
评论列表
文章目录