OneShotMiniImageNetBuilder.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:MatchingNetworks 作者: gitabcworld 项目源码 文件源码
def run_testing_epoch(self):
        """
        Runs one testing epoch over every batch yielded by ``self.test_loader``.

        Each batch must unpack into
        ``(x_support_set, y_support_set, x_target, target_y)``. The support
        labels are expanded into a one-hot encoding with
        ``self.classes_per_set`` classes, and all inputs are moved to the GPU
        when ``self.isCudaAvailable`` is truthy, before being handed to
        ``self.matchingNet``, which returns ``(accuracy, loss)`` for the batch.

        :return: mean_testing_categorical_crossentropy_loss and mean_testing_accuracy,
                 each averaged over the number of batches in the loader
        :raises ZeroDivisionError: if ``self.test_loader`` contains no batches
        """

        def _scalar(value):
            # Results that are 0-dim tensors (PyTorch >= 0.4) cannot be
            # indexed with [0]; .item() is the supported accessor there.
            # Fall back to the legacy .data[0] form when .item() is missing.
            if hasattr(value, "item"):
                return value.item()
            return value.data[0]

        total_test_c_loss = 0.
        total_test_accuracy = 0.
        total_test_batches = len(self.test_loader)

        # Wrap the loader itself (not an enumerate object) and pass the length
        # explicitly so the progress bar can display a total.
        pbar = tqdm(self.test_loader, total=total_test_batches)
        for x_support_set, y_support_set, x_target, target_y in pbar:
            x_support_set = Variable(x_support_set).float()
            y_support_set = Variable(y_support_set, requires_grad=False).long()
            # NOTE(review): squeeze() drops every size-1 dimension, which
            # presumably includes the batch dimension when a batch holds a
            # single episode -- confirm against the loader's batch size.
            x_target = Variable(x_target.squeeze()).float()
            y_target = Variable(target_y.squeeze(), requires_grad=False).long()

            # y_support_set: add a trailing dimension so it can serve as the
            # index tensor for scatter_ along dim 2 of the one-hot buffer.
            y_support_set = torch.unsqueeze(y_support_set, 2)
            batch_size = y_support_set.size()[0]
            sequence_length = y_support_set.size()[1]
            y_support_set_one_hot = torch.FloatTensor(batch_size, sequence_length,
                                                      self.classes_per_set).zero_()
            y_support_set_one_hot.scatter_(2, y_support_set.data, 1)
            y_support_set_one_hot = Variable(y_support_set_one_hot)

            if self.isCudaAvailable:
                x_support_set = x_support_set.cuda()
                y_support_set_one_hot = y_support_set_one_hot.cuda()
                x_target = x_target.cuda()
                y_target = y_target.cuda()

            acc, c_loss_value = self.matchingNet(x_support_set, y_support_set_one_hot,
                                                 x_target, y_target)

            # Extract each Python scalar once and reuse it below.
            batch_loss = _scalar(c_loss_value)
            batch_accuracy = _scalar(acc)

            iter_out = "test_loss: {}, test_accuracy: {}".format(batch_loss, batch_accuracy)
            # Iterating over the tqdm object already advances the bar, so no
            # manual pbar.update(1) is issued here.
            pbar.set_description(iter_out)

            total_test_c_loss += batch_loss
            total_test_accuracy += batch_accuracy

        total_test_c_loss = total_test_c_loss / total_test_batches
        total_test_accuracy = total_test_accuracy / total_test_batches
        return total_test_c_loss, total_test_accuracy
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号