adgm.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:adgm 作者: musyoku 项目源码 文件源码
def sample_ax_label(self, a, x, argmax=True, test=False):
        a = self.to_variable(a)
        x = self.to_variable(x)
        batchsize = x.data.shape[0]
        y_distribution = self.q_y_ax(a, x, test=test, softmax=True).data
        n_labels = y_distribution.shape[1]
        if self.gpu_enabled:
            y_distribution = cuda.to_cpu(y_distribution)
        if argmax:
            sampled_label = np.argmax(y_distribution, axis=1)
        else:
            sampled_label = np.zeros((batchsize,), dtype=np.int32)
            labels = np.arange(n_labels)
            for b in xrange(batchsize):
                label_id = np.random.choice(labels, p=y_distribution[b])
                sampled_label[b] = 1
        return sampled_label
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号