data_loader.py source file

python

Project: multiNLI_encoder  Author: easonnie
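import torch
from torchtext import data, datasets

# RParsedTextLField and ParsedTextLField are not defined in this snippet;
# they are presumably custom fields declared elsewhere in the project.


def load_data(data_root, embd_file, reseversed=True, batch_sizes=(32, 32, 32), device=-1):
    # Pick the text field; `reseversed` (sic) presumably selects the reversed variant.
    if reseversed:
        testl_field = RParsedTextLField()
    else:
        testl_field = ParsedTextLField()

    transitions_field = datasets.snli.ShiftReduceField()
    y_field = data.Field(sequential=False)

    # Load the SNLI splits, then build the text and label vocabularies.
    train, dev, test = datasets.SNLI.splits(testl_field, y_field, transitions_field, root=data_root)
    testl_field.build_vocab(train, dev, test)
    y_field.build_vocab(train)

    # Attach the precomputed embedding vectors stored in embd_file.
    testl_field.vocab.vectors = torch.load(embd_file)

    # One iterator per split; shuffling is disabled for all three, including train.
    train_iter, dev_iter, test_iter = data.Iterator.splits(
        (train, dev, test), batch_sizes=batch_sizes, device=device, shuffle=False)

    return train_iter, dev_iter, test_iter, testl_field.vocab.vectors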
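
The fourth return value of load_data is the embedding tensor loaded from embd_file. The snippet does not show how it is consumed; one common pattern, sketched here as an assumption rather than the project's actual code, is to copy it into an embedding layer. The random `vectors` tensor and its shape are hypothetical stand-ins for the real return value.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the tensor returned by load_data:
# 5 vocabulary entries, 4 dimensions each.
vectors = torch.randn(5, 4)

# Copy the pretrained vectors into an embedding layer; freeze=True
# keeps them fixed during training.
embedding = nn.Embedding.from_pretrained(vectors, freeze=True)

# Look up a batch of token indices (2 sequences of length 3).
token_ids = torch.tensor([[0, 1, 2], [3, 4, 0]])
embedded = embedding(token_ids)
print(embedded.shape)
```

Passing freeze=False instead would let the vectors be fine-tuned along with the rest of the model.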