def generate(self):
    """Yield ``(batch_x, batch_y)`` mini-batches from ``self.src`` and ``self.target``.

    Reads ``self.shuffle``, ``self.vocab_size``, ``self.batch_size``,
    ``self.src`` and ``self.target``.  When ``self.shuffle`` is truthy, a
    fresh NumPy permutation of ``range(self.vocab_size)`` is drawn on every
    call and applied to both tensors, so source and target rows stay paired.
    Otherwise the tensors are walked in their stored order.

    Batches are consecutive slices of ``self.batch_size`` rows; the final
    batch is shorter when ``self.vocab_size`` is not a multiple of
    ``self.batch_size``.

    NOTE(review): ``self.vocab_size`` is used as the number of rows to
    iterate over -- presumably it equals ``len(self.src)``; confirm against
    the constructor.

    Yields:
        tuple: ``(batch_x, batch_y)``, each a contiguous slice wrapped in
        ``Variable``.
    """
    if self.shuffle:
        # A single shared permutation keeps each src row aligned with its
        # target row.  permutation(n) shuffles arange(n), so the explicit
        # arange is unnecessary.
        order = torch.LongTensor(np.random.permutation(self.vocab_size))
        src = self.src[order]
        target = self.target[order]
    else:
        src = self.src
        target = self.target

    for start in range(0, self.vocab_size, self.batch_size):
        stop = start + self.batch_size
        # Slicing past the end is clamped, which produces the short last batch.
        # Variable is kept for compatibility with older PyTorch releases.
        batch_x = Variable(src[start:stop].contiguous())
        batch_y = Variable(target[start:stop].contiguous())
        yield batch_x, batch_y
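# 评论列表 -- web-page residue ("Comment list"), not part of the code
# 文章目录 -- web-page residue ("Article table of contents"), not part of the code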