def classifier(self, xs):
    """
    Classify an image (or a batch of images).

    Uses the trained model q(y|x) = categorical(alpha(x)) and, for each input,
    selects the class with the highest predicted probability.

    :param xs: a batch of scaled vectors of pixels from an image
    :return: the corresponding class labels as one-hots; same shape, dtype and
        device as the class scores returned by ``self.encoder_y``, with a 1 at
        the position of the largest score along the last dimension and 0
        elsewhere (one index is chosen if there are ties)
    """
    # Compute all class probabilities for the image(s).
    alpha = self.encoder_y.forward(xs)
    # Index of the maximum predicted class probability along the class
    # (last) dimension; shape (..., 1), ready to be used by scatter_.
    _, ind = torch.topk(alpha, 1, dim=-1)
    # Build the one-hot tensor(s). zeros_like keeps alpha's dtype and device,
    # so this works for GPU models and non-float32 precision; torch.zeros
    # would always allocate a CPU float32 tensor.
    ys = torch.zeros_like(alpha)
    return ys.scatter_(-1, ind, 1.0)
评论列表
文章目录