def forward(self, batch):
    """Run the generator and discriminator on one minibatch and cache the outputs.

    Each element of ``batch`` is presumably an ``(image, label)`` pair; only
    index 1 is read directly here (TODO confirm against the dataset).

    Steps:
      1. one-hot encode every label with ``self._onehot_encode``;
      2. convert the raw batch and the one-hot batch with ``self.converter``
         (passing ``self.device``);
      3. wrap the resulting arrays in ``Variable`` objects whose ``volatile``
         flag is set whenever ``self.gen.train`` is falsy (Chainer v1 API);
      4. score a "real" discriminator input built from the image and the
         one-hot ground truth, then a "fake" one built from the image and the
         softmax of the generator's prediction.

    Returns nothing. The outputs are stored on ``self`` as ``y_fake``,
    ``y_real``, ``pred_label_map`` and ``ground_truth``, presumably for a
    loss computation elsewhere in the updater.
    """
    onehot_labels = [self._onehot_encode(example[1]) for example in batch]

    img_array, label_array = self.converter(batch, self.device)
    onehot_array = self.converter(onehot_labels, self.device)

    # Read the mode flag once. Per Chainer v1 semantics, volatile Variables
    # do not record a computational graph.
    inference_only = not self.gen.train
    img_var = Variable(img_array, volatile=inference_only)
    label_var = Variable(label_array, volatile=inference_only)
    onehot_var = Variable(onehot_array, volatile=inference_only)

    # The image/ground-truth pairing is scored before the generator runs.
    real_score = self.dis(self._make_dis_input(img_var, onehot_var))

    # The generator output is converted to probabilities before being paired
    # with the image for the discriminator.
    predicted_map = self.gen(img_var)
    fake_score = self.dis(self._make_dis_input(img_var, F.softmax(predicted_map)))

    self.y_fake = fake_score
    self.y_real = real_score
    self.pred_label_map = predicted_map
    self.ground_truth = label_var
updater.py 文件源码
python
阅读 31
收藏 0
点赞 0
评论 0
评论列表
文章目录