test_minibatching.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:mean-teacher 作者: CuriousAI 项目源码 文件源码
def test_stratified_batches():
    """Verify that ``training_batches`` mixes labeled and unlabeled rows.

    The fixture has five rows, two of which carry a non-negative label
    ('b' -> 0, 'c' -> 1). With ``batch_size=3`` and
    ``n_labeled_per_batch=1`` the test expects each batch to start with one
    labeled row followed by two unlabeled rows.

    Over ten batches the test expects:
    * ten labeled draws forming five passes over {'b', 'c'};
    * twenty unlabeled draws forming four passes over all five rows, all
      reported with label -1.
    """
    records = [('a', -1), ('b', 0), ('c', 1), ('d', -1), ('e', -1)]
    record_dtype = [('x', np.str_, 8), ('y', np.int32)]
    data = np.array(records, dtype=record_dtype)

    # Sanity-check the structured-array fixture before using it.
    assert list(data['x']) == ['a', 'b', 'c', 'd', 'e']
    assert list(data['y']) == [-1, 0, 1, -1, -1]

    generator = training_batches(data, batch_size=3, n_labeled_per_batch=1)
    batches = list(islice(generator, 10))

    # Slot 0 of every batch is the labeled share; slots 1: are unlabeled.
    labeled_stream = np.concatenate([batch[:1] for batch in batches])
    unlabeled_stream = np.concatenate([batch[1:] for batch in batches])

    # 10 labeled draws / 2 labeled rows -> 5 epochs;
    # 20 unlabeled draws / 5 rows -> 4 epochs.
    labeled_epochs = np.split(labeled_stream, 5)
    unlabeled_epochs = np.split(unlabeled_stream, 4)

    def sorted_field(epochs, field):
        # Order within an epoch is not checked, only its contents.
        return [sorted(epoch[field].tolist()) for epoch in epochs]

    assert sorted_field(labeled_epochs, 'x') == [['b', 'c']] * 5
    assert sorted_field(labeled_epochs, 'y') == [[0, 1]] * 5
    assert (sorted_field(unlabeled_epochs, 'x') ==
            [['a', 'b', 'c', 'd', 'e']] * 4)
    assert (sorted_field(unlabeled_epochs, 'y') ==
            [[-1, -1, -1, -1, -1]] * 4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号