def test_pretrained_embeddings_d4_2():
    """ 0.5 point(s) """
    # Check that initialize_with_pretrained copies the supplied vectors into
    # the embedder's weight rows, and leaves words without a pretrained
    # vector at whatever values the embedder was initialized with.
    torch.manual_seed(1)  # fixes the random init that the "rate" row is compared against

    vocab = {"interest": 0, "rate": 1, "swap": 2}
    # Only "interest" and "swap" have pretrained vectors; "rate" deliberately does not.
    pretrained = {"interest": [-1.4, 2.6, 3.5], "swap": [1.6, 5.7, 3.2]}

    embedder = VanillaWordEmbeddingLookup(vocab, 3)  # dim 3 matches the pretrained vectors
    initialize_with_pretrained(pretrained, embedder)
    weights = embedder.word_embeddings.weight.data

    # Expected row per word. The "rate" values are presumably the embedder's
    # seeded random init under torch.manual_seed(1) — they will change if the
    # constructor's RNG consumption changes; TODO confirm against the reference.
    expected = {
        "interest": pretrained["interest"],
        "rate": [-2.2820, 0.5237, 0.0004],
        "swap": pretrained["swap"],
    }

    # Compare rows in a fixed order: interest, rate, swap.
    pairs = [
        (weights[vocab[word]].tolist(), expected[word])
        for word in ("interest", "rate", "swap")
    ]
    check_tensor_correctness(pairs)
test_parser.py 文件源码
python
阅读 32
收藏 0
点赞 0
评论 0
评论列表
文章目录