def test_feature_extraction_d2_2():
    """ 0.5 point(s) """
    # Build a parser state from test_sent, shift twice, and check that
    # SimpleFeatureExtractor.get_features returns the expected feature vectors.
    # test_sent, word_to_ix and TEST_EMBEDDING_DIM are module-level names
    # defined elsewhere in this file; they are only read here, so no `global`
    # declaration is required.

    # Seed the RNG before building the embedder so the hard-coded expected
    # values below are reproducible (presumably the embedding weights are
    # randomly initialised -- confirm in VanillaWordEmbeddingLookup).
    torch.manual_seed(1)

    feat_extractor = SimpleFeatureExtractor()
    embedder = VanillaWordEmbeddingLookup(word_to_ix, TEST_EMBEDDING_DIM)
    combiner = DummyCombiner()

    embeds = embedder(test_sent)
    state = ParserState(test_sent, embeds, combiner)
    state.shift()
    state.shift()

    feats = feat_extractor.get_features(state)
    # Materialise so the number of features can be checked before comparing.
    feats_list = list(make_list(feats))

    expected = (
        [-1.8661, 1.4146, -1.8781, -0.4674],
        [-0.9596, 0.5489, -0.9901, -0.3826],
        [0.5237, 0.0004, -1.2039, 3.5283],
    )

    # zip() stops at the shorter input, so without this check a missing or
    # extra feature would silently go uncompared and the test could pass.
    assert len(feats_list) == len(expected), (
        "expected %d features, got %d" % (len(expected), len(feats_list))
    )

    check_tensor_correctness(zip(feats_list, expected))
test_parser.py 文件源码
python
阅读 46
收藏 0
点赞 0
评论 0
评论列表
文章目录