test_parser.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:DeepDependencyParsingProblemSet 作者: rguthrie3 项目源码 文件源码
def test_parse_logic_d3_1():
    """ 0.5 point(s)

    Build a TransitionParser from freshly constructed components, run it on
    the module-level test sentence with ``gold`` as the second argument, and
    check the number of outputs, one output vector, the resulting dependency
    graph and the list of actions performed.
    """
    # Reference values the parser run is compared against.
    expected_num_decisions = 15
    expected_scores = [-1.4737, -1.0875, -0.8350]
    expected_actions = [0, 0, 1, 0, 1, 0, 0, 1, 2, 0, 0, 0, 1, 1, 2]

    # Seed first and keep the construction order below unchanged.
    # NOTE(review): presumably the components draw their initial weights from
    # torch's RNG, so the hard-coded scores depend on this order -- confirm.
    torch.manual_seed(1)

    extractor = SimpleFeatureExtractor()
    embedding_lookup = VanillaWordEmbeddingLookup(word_to_ix, TEST_EMBEDDING_DIM)
    action_network = ActionChooserNetwork(TEST_EMBEDDING_DIM * NUM_FEATURES)
    combiner_network = MLPCombinerNetwork(TEST_EMBEDDING_DIM)
    transition_parser = TransitionParser(
        extractor, embedding_lookup, action_network, combiner_network
    )

    # The parser is given the sentence without its last element, whereas the
    # oracle graph further down is built from the full sentence.
    sentence_for_parser = test_sent[:-1]
    result = transition_parser(sentence_for_parser, gold)
    output, dep_graph, actions_done = result

    # Made the right number of decisions
    assert len(output) == expected_num_decisions

    # Spot-check a single output: flatten entry 10 into a plain list.
    flattened_scores = output[10].view(-1).data.tolist()
    check_tensor_correctness([(expected_scores, flattened_scores)])

    # The graph returned by the parser must equal the oracle-derived graph.
    oracle_graph = dependency_graph_from_oracle(test_sent, gold)
    assert oracle_graph == dep_graph
    assert actions_done == expected_actions
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号