test_parser.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:DeepDependencyParsingProblemSet 作者: rguthrie3 项目源码 文件源码
def test_lstm_combiner_d4_3():
    """ 1 point(s) """
    # Reference-output test for LSTMCombinerNetwork: after combining the same
    # (head, modifier) pair three times in a row, the flattened output must
    # match the hard-coded expected values.

    # Seed before anything touches the RNG. The network construction
    # (presumably random parameter init — TODO confirm) and the two feature
    # draws below must happen in exactly this order, or `expected` no longer
    # corresponds to the generated inputs.
    torch.manual_seed(1)
    combiner = LSTMCombinerNetwork(TEST_EMBEDDING_DIM, 1, 0.0)
    head_feat = ag.Variable(torch.randn(1, TEST_EMBEDDING_DIM))
    modifier_feat = ag.Variable(torch.randn(1, TEST_EMBEDDING_DIM))

    # Combine repeatedly so the sequential part of the combiner is exercised;
    # only the output of the final call is compared.
    for _ in range(3):
        combined = combiner(head_feat, modifier_feat)

    expected = [0.0873, -0.1837, 0.1975, -0.1166]
    actual = combined.view(-1).data.tolist()
    check_tensor_correctness([(actual, expected)])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号