updater.py source code

python

Project: Semantic-Segmentation-using-Adversarial-Networks    Author: oyam
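# Imports assumed from context (Chainer v1-style API, where Variable accepts `volatile`).
from chainer import Variable
import chainer.functions as F


# Method of the project's updater class (excerpt).
def forward(self, batch):
    # One-hot encode the ground-truth label map of every (image, label) pair.
    label_onehot_batch = [self._onehot_encode(pair[1]) for pair in batch]

    # Stack the examples into arrays and transfer them to the target device.
    input_img, ground_truth = self.converter(batch, self.device)
    ground_truth_onehot = self.converter(label_onehot_batch, self.device)

    # Volatile variables skip building the computational graph when not training.
    volatile = not self.gen.train
    input_img = Variable(input_img, volatile=volatile)
    ground_truth = Variable(ground_truth, volatile=volatile)
    ground_truth_onehot = Variable(ground_truth_onehot, volatile=volatile)

    # Discriminator output for (image, ground-truth one-hot label map) pairs.
    x_real = self._make_dis_input(input_img, ground_truth_onehot)
    y_real = self.dis(x_real)

    # Discriminator output for (image, predicted class-probability map) pairs.
    pred_label_map = self.gen(input_img)
    x_fake = self._make_dis_input(input_img, F.softmax(pred_label_map))
    y_fake = self.dis(x_fake)

    # Store the outputs on the updater (presumably consumed by the loss computation).
    self.y_fake = y_fake
    self.y_real = y_real
    self.pred_label_map = pred_label_map
    self.ground_truth = ground_truth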
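The helpers `_onehot_encode` and `_make_dis_input` are not shown in this excerpt. Below is a minimal NumPy sketch of what they might do, under stated assumptions: labels are an (H, W) integer map with -1 marking ignored pixels, and the discriminator input is the image concatenated with a class map along the channel axis. The names `onehot_encode`, `softmax_channels`, `make_dis_input`, `n_class`, and `ignore_label` are hypothetical and are not the project's code.

```python
import numpy as np


def onehot_encode(label, n_class, ignore_label=-1):
    """Hypothetical stand-in for `_onehot_encode`.

    Converts an (H, W) integer label map into a (n_class, H, W) float32 array.
    Pixels equal to `ignore_label` become all-zero vectors.
    """
    h, w = label.shape
    onehot = np.zeros((n_class, h, w), dtype=np.float32)
    valid = label != ignore_label
    rows, cols = np.nonzero(valid)  # row-major order, same as label[valid]
    onehot[label[valid], rows, cols] = 1.0
    return onehot


def softmax_channels(scores):
    """Softmax over axis 1 of an (N, C, H, W) score map (Chainer's default axis)."""
    shifted = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def make_dis_input(img, label_map):
    """Hypothetical stand-in for `_make_dis_input`.

    Concatenates (N, 3, H, W) images with (N, C, H, W) label maps along the
    channel axis, giving an (N, 3 + C, H, W) discriminator input.
    """
    return np.concatenate([img, label_map], axis=1)


rs = np.random.RandomState(0)
n_class = 4
label = np.array([[0, 1], [3, -1]])  # bottom-right pixel is ignored
onehot = onehot_encode(label, n_class)  # shape (4, 2, 2)

img = rs.randn(1, 3, 2, 2).astype(np.float32)
scores = rs.randn(1, n_class, 2, 2).astype(np.float32)  # raw generator output

x_real = make_dis_input(img, onehot[None])  # add batch axis to the one-hot map
x_fake = make_dis_input(img, softmax_channels(scores))
print(x_real.shape, x_fake.shape)  # (1, 7, 2, 2) (1, 7, 2, 2)
```

As in `forward`, the "real" input carries a hard one-hot map while the "fake" input carries soft probabilities, so both have the same channel count. Plain concatenation is only the simplest conditioning choice; the actual project may encode the pair differently.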