def bdnn_prediction(bdnn_batch_size, logits, threshold=th):
    """Average overlapping per-frame outputs and threshold them into 0/1 decisions.

    NOTE(review): this function relies on module-level names defined outside
    it: ``np``, ``utils`` (for ``utils.bdnn_transform``), the window
    parameters ``w`` and ``u``, and the default threshold ``th``. ``th`` is
    read once, when the function is defined.

    How it works:

    * Each frame of the batch is tagged with its 1-based position.
    * The tags are passed through ``utils.bdnn_transform(indx, w, u)``.
    * The first and last ``w`` rows of the transformed tags are dropped.
    * For every frame ``i`` in ``[w, bdnn_batch_size - w)``, all entries of
      ``logits`` at positions whose tag refers to ``i`` are averaged.
    * The average is compared against ``threshold``.

    Args:
        bdnn_batch_size: Number of frames in the batch.
        logits: Network outputs, indexed with the index tuple returned by
            ``np.where`` on the trimmed, transformed tag array. Presumably it
            has the same shape as that array -- TODO confirm with the caller.
        threshold: Decision threshold. An averaged score ``>= threshold``
            becomes 1, otherwise 0.

    Returns:
        An ``int`` array of shape ``(bdnn_batch_size - 2*w, 1)`` holding the
        0/1 decisions for frames ``w .. bdnn_batch_size - w - 1`` (empty when
        ``bdnn_batch_size <= 2*w``).
    """
    # 1-based frame tags as a column vector. Starting at 1 means a tag of 0
    # never matches a real frame below; presumably bdnn_transform pads with
    # zeros (TODO confirm).
    indx = (np.arange(bdnn_batch_size) + 1).reshape((bdnn_batch_size, 1))
    indx = utils.bdnn_transform(indx, w, u)
    indx = indx[w:(bdnn_batch_size - w), :]
    # Convert back to 0-based frame numbers once, instead of on every iteration.
    frame_ids = indx - 1

    result = np.zeros((bdnn_batch_size, 1))
    for i in range(w, bdnn_batch_size - w):
        # All logits whose position in the transformed array refers to frame i.
        pred = logits[np.where(frame_ids == i)]
        # NOTE(review): if frame i never appears, pred is empty and this
        # yields NaN (with a RuntimeWarning). Confirm that cannot happen.
        result[i] = np.sum(pred) / pred.shape[0]

    # Keep exactly the rows the loop filled. np.trim_zeros (used previously)
    # also discarded genuine 0.0 averages at either end, which made the output
    # length data-dependent. It is also not well-defined for 2-D input across
    # NumPy versions.
    result = result[w:bdnn_batch_size - w]
    return (result >= threshold).astype(int)
# Scraped web-page residue ("Comment list", "Table of contents"); not part of the program.