def load_data(data_root, embd_file, reseversed=True, batch_sizes=(32, 32, 32), device=-1,
              shuffle=False):
    """Build SNLI train/dev/test iterators and attach pre-computed embeddings.

    Args:
        data_root: Root directory handed to ``datasets.SNLI.splits``.
        embd_file: Path loaded with ``torch.load``; the loaded object is assigned
            to the text vocabulary's ``vectors``. NOTE(review): assumed to be an
            embedding tensor whose rows line up with the vocabulary built here
            (train + dev + test) -- confirm against whatever produced the file.
        reseversed: If true, use ``RParsedTextLField`` for the text; otherwise
            ``ParsedTextLField``. (Spelling kept for caller compatibility;
            presumably means "reversed".)
        batch_sizes: Batch sizes for (train, dev, test), or a single int used
            for all three splits.
        device: Forwarded to ``data.Iterator.splits``. In legacy torchtext
            ``-1`` selected the CPU -- confirm against the installed version.
        shuffle: Forwarded to ``data.Iterator.splits``. The default ``False``
            disables shuffling on every split, including training.

    Returns:
        ``(train_iter, dev_iter, test_iter, vectors)``, where ``vectors`` is
        the object loaded from ``embd_file``.

    Raises:
        ValueError: If ``batch_sizes`` does not give exactly three sizes.
    """
    # Validate up front so a bad argument fails before the slow dataset load.
    if isinstance(batch_sizes, int):
        batch_sizes = (batch_sizes,) * 3
    batch_sizes = tuple(batch_sizes)
    if len(batch_sizes) != 3:
        raise ValueError(
            "batch_sizes must give one size per split (train, dev, test); "
            "got {!r}".format(batch_sizes))

    if reseversed:
        text_field = RParsedTextLField()
    else:
        text_field = ParsedTextLField()
    # Third positional argument of SNLI.splits; presumably yields shift-reduce
    # transition sequences from the parses -- confirm for the torchtext version used.
    transitions_field = datasets.snli.ShiftReduceField()
    label_field = data.Field(sequential=False)

    train, dev, test = datasets.SNLI.splits(
        text_field, label_field, transitions_field, root=data_root)

    # Text vocab covers all three splits (presumably so the loaded embedding
    # matrix indexes every token seen at evaluation); labels come from train only.
    text_field.build_vocab(train, dev, test)
    label_field.build_vocab(train)
    # SECURITY: torch.load unpickles the file; only pass trusted embedding files.
    text_field.vocab.vectors = torch.load(embd_file)

    train_iter, dev_iter, test_iter = data.Iterator.splits(
        (train, dev, test), batch_sizes=batch_sizes, device=device, shuffle=shuffle)
    return train_iter, dev_iter, test_iter, text_field.vocab.vectors
评论列表
文章目录