model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:densecap-tensorflow 作者: rampage644 项目源码 文件源码
def generate_batches(positive_batch, negative_batch, batch_size):
    positive_boxes, positive_scores, positive_labels = positive_batch
    negative_boxes, negative_scores, negative_labels = negative_batch

    half_batch = batch_size // 2

    pos_batch = np.concatenate([positive_boxes, positive_scores, positive_labels], axis=1)
    neg_batch = np.concatenate([negative_boxes, negative_scores, negative_labels], axis=1)

    np.random.shuffle(pos_batch)
    np.random.shuffle(neg_batch)

    pos_batch = pos_batch[:half_batch]
    pad_size = half_batch - len(pos_batch)
    pos_batch = np.concatenate([pos_batch, neg_batch[:pad_size]])
    neg_batch = neg_batch[pad_size:pad_size+half_batch]

    return (
        np.split(pos_batch, [4, 6], axis=1),
        np.split(neg_batch, [4, 6], axis=1)
    )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号