def get_loader(src_path, trg_path, src_word2id, trg_word2id, batch_size=100):
    """Build a shuffling mini-batch loader over a source/target text pair.

    Args:
        src_path: Path to the txt file holding the source-domain data.
        trg_path: Path to the txt file holding the target-domain data.
        src_word2id: Word-to-id dictionary for the source domain.
        trg_word2id: Word-to-id dictionary for the target domain.
        batch_size: Number of examples per mini-batch (defaults to 100).

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the custom ``Dataset``,
        with shuffling enabled and ``collate_fn`` used to assemble batches.
    """
    # Wrap both text files and their vocabularies in the custom dataset.
    paired_corpus = Dataset(src_path, trg_path, src_word2id, trg_word2id)

    # Batch assembly is delegated to collate_fn; each iteration is expected
    # to yield (src_seqs, src_lengths, trg_seqs, trg_lengths) -- see
    # collate_fn for the details.
    return torch.utils.data.DataLoader(
        dataset=paired_corpus,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )
# (scraped web-page residue) Comment list
# (scraped web-page residue) Table of contents