def test_stratified_batches():
    """Verify the labeled/unlabeled composition of ``training_batches`` output.

    Ten batches of size 3 are drawn with one labeled example per batch.
    The first row of each batch is treated as the labeled portion and the
    remaining two rows as the unlabeled portion. Concatenated, the 10
    labeled rows are split into 5 equal groups and the 20 unlabeled rows
    into 4 equal groups. Every labeled group must be exactly the rows
    'b' and 'c' (labels 0 and 1), and every unlabeled group must be all
    five rows with ``y == -1``.

    NOTE(review): the expectations imply that ``training_batches`` reports
    ``y == -1`` for 'b' and 'c' when they appear in the unlabeled portion
    -- confirm against its implementation.
    """
    records = [('a', -1), ('b', 0), ('c', 1), ('d', -1), ('e', -1)]
    data = np.array(records, dtype=[('x', np.str_, 8), ('y', np.int32)])

    # Sanity-check the fixture before exercising the batch generator.
    assert list(data['x']) == ['a', 'b', 'c', 'd', 'e']
    assert list(data['y']) == [-1, 0, 1, -1, -1]

    # Take only the first ten batches from the generator.
    batches = list(islice(
        training_batches(data, batch_size=3, n_labeled_per_batch=1), 10))

    # Row 0 of each batch is the labeled portion; rows 1.. are unlabeled.
    labeled_rows = np.concatenate([batch[:1] for batch in batches])
    unlabeled_rows = np.concatenate([batch[1:] for batch in batches])

    # 10 labeled rows -> 5 groups of 2; 20 unlabeled rows -> 4 groups of 5.
    labeled_epochs = np.split(labeled_rows, 5)
    unlabeled_epochs = np.split(unlabeled_rows, 4)

    def sorted_field(groups, field):
        # Order within a group is not asserted, so compare sorted values.
        return [sorted(group[field].tolist()) for group in groups]

    assert sorted_field(labeled_epochs, 'x') == [['b', 'c']] * 5
    assert sorted_field(labeled_epochs, 'y') == [[0, 1]] * 5
    assert (sorted_field(unlabeled_epochs, 'x') ==
            [['a', 'b', 'c', 'd', 'e']] * 4)
    assert (sorted_field(unlabeled_epochs, 'y') ==
            [[-1, -1, -1, -1, -1]] * 4)
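# (scraped page residue, not part of the test) Comment list
# (scraped page residue, not part of the test) Article contents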