def errors_top_x(self, y, num_top=5):
    """Return the symbolic top-``num_top`` error rate over a minibatch.

    An example counts as an error when its label in ``y`` is not among the
    ``num_top`` column indices with the largest values in the corresponding
    row of ``self.p_y_given_x``. The result is the mean of those per-example
    0/1 errors, i.e. the fraction of examples missed.

    :param y: symbolic label tensor; must have the same ``ndim`` as
        ``self.y_pred`` and an integer dtype. Presumably a vector holding one
        class index per row of ``self.p_y_given_x`` -- confirm with callers.
    :param num_top: number of highest-scoring classes that count as a hit;
        must be >= 1.
    :raises TypeError: if ``y.ndim`` differs from ``self.y_pred.ndim``.
    :raises ValueError: if ``num_top`` is less than 1.
    :raises NotImplementedError: if ``y`` does not have an integer dtype.
    """
    if y.ndim != self.y_pred.ndim:
        raise TypeError('y should have the same shape as self.y_pred',
                        ('y', y.type, 'y_pred', self.y_pred.type))
    # Guard against num_top <= 0: with 0 the slice below becomes
    # [:, -0:] == [:, 0:], i.e. every class, silently yielding a 0% error.
    if num_top < 1:
        raise ValueError('num_top must be >= 1, got %r' % (num_top,))
    if num_top != 5:
        print('val errors from top %d' % num_top)
    # Only integer class labels are supported.
    if not y.dtype.startswith('int'):
        raise NotImplementedError()
    # argsort is ascending along axis 1, so the last num_top columns hold the
    # indices of the num_top highest-scoring classes for each row.
    y_pred_top_x = T.argsort(self.p_y_given_x, axis=1)[:, -num_top:]
    # Tile each label across num_top columns so it can be compared
    # elementwise with that row's candidate classes.
    y_top_x = y.reshape((y.shape[0], 1)).repeat(num_top, axis=1)
    # T.neq gives 1 where a candidate differs from the label; the row minimum
    # is 1 only when *all* num_top candidates miss, i.e. a top-k error.
    return T.mean(T.min(T.neq(y_pred_top_x, y_top_x).astype('int8'), axis=1))
评论列表
文章目录