classification.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:gmn 作者: sbos 项目源码 文件源码
def one_shot_classification(test_data, num_shots, num_classes, compute_similarities, k_neighbours=1,
                            num_episodes=10000):
    """Estimate few-shot nearest-neighbour accuracy over random episodes.

    Each episode draws ``num_classes`` distinct classes and, for every drawn
    class, ``num_shots + 1`` distinct samples.  The first ``num_shots`` samples
    of every class form a support set shared by all rows of the batch; the
    remaining sample of class ``k`` is the query stored in the last column of
    row ``k``.  A query is classified by a similarity-weighted vote among its
    ``k_neighbours`` most similar support samples.

    Args:
        test_data: array indexed as ``[class, sample, ...]``.  Trailing sample
            dimensions are flattened into a single feature axis.
        num_shots: number of support samples per class in an episode.
        num_classes: number of classes per episode.
        compute_similarities: callable receiving the float32 batch of shape
            ``(num_classes, num_shots * num_classes + 1, features)`` and
            returning a 2-D array ``sim`` where ``sim[i, j]`` is the
            similarity between the query ``batch[i, -1]`` and ``batch[i, j]``.
            The same batch buffer is reused (overwritten) on every episode.
            NOTE(review): ``j`` is used to index the support labels, so the
            result presumably must not contain a column for the query itself
            -- confirm against the similarity function in use.
        k_neighbours: how many of the most similar support samples vote.
        num_episodes: number of random episodes to average over.

    Returns:
        Fraction of correctly classified queries over all episodes.

    Raises:
        ValueError: if ``num_episodes`` or ``k_neighbours`` is smaller than 1.
    """
    if num_episodes < 1:
        raise ValueError('num_episodes must be at least 1, got %r' % (num_episodes,))
    if k_neighbours < 1:
        # A slice of [-0:] would silently select *every* support sample.
        raise ValueError('k_neighbours must be at least 1, got %r' % (k_neighbours,))

    # Flatten each sample to a vector so that multi-dimensional samples
    # (e.g. images) can be copied into the 3-D batch buffer below.
    samples = np.reshape(test_data, (test_data.shape[0], test_data.shape[1], -1))
    data_shape = samples.shape[2]
    num_support = num_shots * num_classes
    batch = np.zeros([num_classes, num_support + 1, data_shape], dtype=np.float32)

    # Episode-local label (0 .. num_classes - 1) of every support column;
    # identical for all episodes, so it is built once.
    y = np.repeat(np.arange(num_classes), num_shots)

    accuracy = 0.
    votes = np.zeros(num_classes)

    for episode in range(num_episodes):
        classes = np.random.choice(samples.shape[0], num_classes, False)
        classes_idx = np.repeat(classes, num_shots)
        # One row of distinct sample indices per drawn class: the first
        # num_shots entries are support samples, the last one is the query.
        idx = np.vstack([
            np.random.choice(samples.shape[1], num_shots + 1, False)
            for _ in range(num_classes)
        ])

        # The support set is broadcast to every row; only the query differs.
        batch[:, :-1, :] = samples[classes_idx, idx[:, :-1].flatten(), :]
        batch[:, -1, :] = samples[classes, idx[:, -1], :]

        # sim[i, j] -- similarity between batch[i, -1] and batch[i, j]
        sim = compute_similarities(batch)

        for k in range(num_classes):
            votes[:] = 0.
            # argsort is ascending, so the tail holds the most similar samples.
            nearest = sim[k].argsort()[-k_neighbours:]
            for j in nearest:
                votes[y[j]] += sim[k, j]
            if votes.argmax() == k:
                accuracy += 1

        status = 'episode: %d, accuracy: %f' % (episode, accuracy / num_classes / (episode + 1))
        sys.stdout.write('\r' + status)
        sys.stdout.flush()

    return accuracy / num_episodes / num_classes
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号