def _as_fields(batch, name):
    """Unpack a ``(boxes, scores, labels)`` triple and check the rows line up."""
    # Unpacking a generator still raises ValueError for anything but 3 items.
    boxes, scores, labels = (np.asarray(a) for a in batch)
    fields = (boxes, scores, labels)
    lengths = [len(f) for f in fields]
    if len(set(lengths)) != 1:
        raise ValueError(
            f"{name}: boxes, scores and labels must have the same number of "
            f"rows, got {lengths}"
        )
    return fields


def generate_batches(positive_batch, negative_batch, batch_size, *, rng=None):
    """Sample a positive and a negative mini-batch of ``batch_size // 2`` rows each.

    Positives are shuffled and at most ``batch_size // 2`` of them are kept.
    If there are fewer positives than that, the positive batch is topped up
    with randomly drawn negatives. The negative batch then takes the next
    ``batch_size // 2`` shuffled negatives after the ones used as padding, so
    no negative row appears in both outputs. With an odd ``batch_size`` the
    extra slot is left unused. Either batch can come back shorter than
    requested when the inputs run out.

    Each field is indexed separately, so it keeps its own dtype and trailing
    shape. For example, integer labels stay integer, and scores may have any
    number of columns.

    Args:
        positive_batch: ``(boxes, scores, labels)`` array-likes for the
            positive samples. All three must have the same length along axis 0.
        negative_batch: ``(boxes, scores, labels)`` for the negative samples,
            with the same layout as ``positive_batch``.
        batch_size: Total requested size of both batches together. Must be
            non-negative.
        rng: Optional ``numpy.random.Generator`` or ``RandomState`` used for
            shuffling. Defaults to the global ``numpy.random`` state.

    Returns:
        A tuple ``(pos, neg)``. Each element is a list ``[boxes, scores, labels]``
        of arrays whose rows stay aligned across the three fields.

    Raises:
        ValueError: If ``batch_size`` is negative, if a batch is not a 3-item
            sequence, or if the fields of one batch differ in row count.
    """
    if batch_size < 0:
        raise ValueError(f"batch_size must be non-negative, got {batch_size}")
    sampler = np.random if rng is None else rng
    half_batch = batch_size // 2

    pos_fields = _as_fields(positive_batch, "positive_batch")
    neg_fields = _as_fields(negative_batch, "negative_batch")

    # Shuffle row indices rather than a merged array, so no dtype promotion
    # happens and no fixed column split points are needed.
    pos_order = sampler.permutation(len(pos_fields[0]))[:half_batch]
    neg_order = sampler.permutation(len(neg_fields[0]))

    # When positives run short, borrow the first shuffled negatives as
    # padding. The negative batch starts right after them, so there is no overlap.
    pad_size = half_batch - len(pos_order)
    pad_order = neg_order[:pad_size]
    neg_order = neg_order[pad_size:pad_size + half_batch]

    pos_out = [
        np.concatenate([pos_field[pos_order], neg_field[pad_order]])
        for pos_field, neg_field in zip(pos_fields, neg_fields)
    ]
    neg_out = [neg_field[neg_order] for neg_field in neg_fields]
    return pos_out, neg_out
评论列表
文章目录