test_minibatching.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:mean-teacher 作者: CuriousAI 项目源码 文件源码
def test_batches_from_two_sets():
    """Combine batches from two eternal streams and check per-epoch contents.

    A 2-item array is batched with size 1 and a 3-item array with size 2.
    Six combined batches are drawn; each must contain 3 items. The first
    item of every batch is treated as the first stream's portion and the
    remaining two as the second stream's portion. When those portions are
    cut into dataset-sized chunks, every chunk must contain each item of
    its dataset exactly once.
    """
    small_set = np.array(['a', 'b'])
    large_set = np.array(['c', 'd', 'e'])

    combined = combine_batches(
        eternal_batches(small_set, batch_size=1),
        eternal_batches(large_set, batch_size=2),
    )
    drawn = list(islice(combined, 6))

    # 1 item from the first stream + 2 from the second in every batch.
    assert list(map(len, drawn)) == [3, 3, 3, 3, 3, 3]

    # Split each batch back into the slices attributed to each stream.
    from_small = np.concatenate([batch[:1] for batch in drawn])
    from_large = np.concatenate([batch[1:] for batch in drawn])

    # 6 items / 2 per epoch = 3 epochs; 12 items / 3 per epoch = 4 epochs.
    for epoch in np.split(from_small, 3):
        assert sorted(epoch) == ['a', 'b']
    for epoch in np.split(from_large, 4):
        assert sorted(epoch) == ['c', 'd', 'e']
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号