def combine_label_batch(num0, num1, numt=0, order='01'):
    """Build a batch of one-hot labels for two classes plus an optional third group.

    Each row is a one-hot vector of length 3: ``(1, 0, 0)`` for class 0,
    ``(0, 1, 0)`` for class 1 and ``(0, 0, 1)`` for the third ("t") group.

    Args:
        num0: Number of class-0 rows.
        num1: Number of class-1 rows.
        numt: Number of third-group rows; these are always placed last.
        order: ``'01'`` places the class-0 rows before the class-1 rows;
            ``'10'`` places the class-1 rows first.

    Returns:
        A ``float32`` array of shape ``(num0 + num1 + numt, 3)``.

    Raises:
        ValueError: If ``order`` is not ``'01'`` or ``'10'``. NumPy also
            raises ``ValueError`` if any count is negative.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under ``python -O``.
    if order not in ('01', '10'):
        raise ValueError(f"order must be '01' or '10', got {order!r}")

    # Row i of the 3x3 identity matrix is the one-hot label for group i.
    one_hot = np.eye(3, dtype='float32')
    if order == '01':
        rows, counts = [0, 1, 2], [num0, num1, numt]
    else:
        rows, counts = [1, 0, 2], [num1, num0, numt]

    # Repeat each selected label row by its count. This builds the float32
    # batch directly, replacing the deprecated ``np.row_stack`` plus a cast.
    return np.repeat(one_hot[rows], counts, axis=0)
评论列表
文章目录