test_parser.py 文件源码

python
阅读 46 收藏 0 点赞 0 评论 0

项目:DeepDependencyParsingProblemSet 作者: rguthrie3 项目源码 文件源码
def test_feature_extraction_d2_2():
    """ 0.5 point(s) """

    global test_sent, gold, word_to_ix, vocab
    torch.manual_seed(1)

    feat_extractor = SimpleFeatureExtractor()
    embedder = VanillaWordEmbeddingLookup(word_to_ix, TEST_EMBEDDING_DIM)
    combiner = DummyCombiner()
    embeds = embedder(test_sent)
    state = ParserState(test_sent, embeds, combiner)

    state.shift()
    state.shift()

    feats = feat_extractor.get_features(state)
    feats_list = make_list(feats)
    true = ([ -1.8661, 1.4146, -1.8781, -0.4674 ], [ -0.9596, 0.5489, -0.9901, -0.3826 ], [ 0.5237, 0.0004, -1.2039, 3.5283 ])
    pairs = zip(feats_list, true)
    check_tensor_correctness(pairs)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号