def test_lstm_combiner_d4_3():
    """ 1 point(s)

    Regression check for LSTMCombinerNetwork: after seeding the RNG, the
    combiner is applied three times in a row to the same head/modifier
    pair, and the flattened result of the last call is compared against
    stored reference values.
    """
    # Seed first: the network construction and both randn() calls below
    # draw from the global RNG, so their order must not change.
    torch.manual_seed(1)
    # NOTE(review): the meaning of the trailing arguments (1, 0.0) is not
    # visible here -- presumably layer count and dropout; confirm against
    # LSTMCombinerNetwork's constructor.
    network = LSTMCombinerNetwork(TEST_EMBEDDING_DIM, 1, 0.0)
    head_vec = ag.Variable(torch.randn(1, TEST_EMBEDDING_DIM))
    modifier_vec = ag.Variable(torch.randn(1, TEST_EMBEDDING_DIM))

    # Invoke the combiner repeatedly so that the sequential part of the
    # implementation is exercised; only the final output is checked.
    output = None
    for _ in range(3):
        output = network(head_vec, modifier_vec)

    flattened = output.view(-1)
    actual_values = flattened.data.tolist()
    expected_values = [0.0873, -0.1837, 0.1975, -0.1166]
    check_tensor_correctness([(actual_values, expected_values)])
test_parser.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录