def predict(self, outputs, targets, weights, criterion):
    """Run the parent ``predict`` chunk-by-chunk and merge the results.

    ``outputs``, ``targets`` and ``weights`` are each split with
    ``torch.split`` into pieces of ``self.batch_size`` along ``self.dim``
    (the last piece may be smaller). Each aligned triple of pieces is
    passed to the parent class's ``predict`` together with ``criterion``.
    Presumably this exists to bound peak memory, as the class name
    ``MemEfficientGenerator`` suggests -- confirm against the parent.

    Args:
        outputs: Tensor split along ``self.dim`` and forwarded as the
            parent's ``outputs`` argument.
        targets: Tensor split the same way and forwarded as ``targets``.
        weights: Tensor split the same way and forwarded as ``weights``.
        criterion: Passed unchanged to every parent ``predict`` call.

    Returns:
        A ``(preds, loss)`` tuple: the per-chunk predictions concatenated
        along ``self.dim``, and the sum of the per-chunk losses.

    Raises:
        ValueError: If the three inputs do not yield the same number of
            chunks along ``self.dim``. Without this check ``zip`` would
            silently drop the trailing chunks of the longer inputs,
            returning truncated predictions and a partial loss.
    """
    outputs_split = torch.split(outputs, self.batch_size, self.dim)
    targets_split = torch.split(targets, self.batch_size, self.dim)
    weights_split = torch.split(weights, self.batch_size, self.dim)
    # zip() stops at the shortest input; refuse mismatched inputs instead of
    # silently processing only a prefix of the data.
    if not len(outputs_split) == len(targets_split) == len(weights_split):
        raise ValueError(
            "predict: outputs, targets and weights must split into the same "
            "number of chunks along dim {} (got {}, {}, {})".format(
                self.dim, len(outputs_split), len(targets_split),
                len(weights_split)))
    preds = []
    # Start from a plain 0 so the first addition adopts whatever loss type
    # (e.g. a tensor) the parent predict returns.
    loss = 0
    for out_t, targ_t, w_t in zip(outputs_split, targets_split, weights_split):
        preds_t, loss_t = super(MemEfficientGenerator, self).predict(
            out_t, targ_t, w_t, criterion)
        preds.append(preds_t)
        loss += loss_t
    preds = torch.cat(preds, self.dim)
    return preds, loss
评论列表
文章目录