def run(title, base_batch_size, base_labeled_batch_size, base_lr, n_labels, data_seed, **kwargs):
    """Launch one training run with batch sizes and learning rate scaled by GPU count.

    The base batch sizes and base learning rate are all multiplied by the
    number of visible CUDA devices (treated as at least 1). Because both are
    scaled by the same factor, the lr-to-batch-size ratio stays constant.

    Args:
        title: Human-readable run name; only logged by this function.
        base_batch_size: Batch size per device.
        base_labeled_batch_size: Per-device value for ``labeled_batch_size``
            (presumably the number of labeled examples in each batch — confirm
            against ``parse_dict_args``).
        base_lr: Learning rate per device.
        n_labels: Selects the label directory
            ``data-local/labels/cifar10/<n_labels>_balanced_labels``.
        data_seed: Selects the label file ``<data_seed, zero-padded to 2>.txt``
            inside that directory. Also used with ``n_labels`` to name the run
            context.
        **kwargs: Extra arguments forwarded to ``parse_dict_args``. They must
            not repeat ``batch_size``, ``labeled_batch_size``, ``lr`` or
            ``labels``; a duplicate keyword raises ``TypeError``.

    Side effects:
        Assigns ``main.args`` and then calls ``main.main(context)``.
        The return value of ``main.main`` is discarded.
    """
    LOG.info('run title: %s', title)

    # device_count() is 0 on CPU-only hosts. Without the floor, both batch
    # sizes and the learning rate would be scaled to 0.
    ngpu = max(1, torch.cuda.device_count())

    adapted_args = {
        'batch_size': base_batch_size * ngpu,
        'labeled_batch_size': base_labeled_batch_size * ngpu,
        'lr': base_lr * ngpu,
        'labels': 'data-local/labels/cifar10/{}_balanced_labels/{:02d}.txt'.format(n_labels, data_seed),
    }
    context = RunContext(__file__, "{}_{}".format(n_labels, data_seed))
    main.args = parse_dict_args(**adapted_args, **kwargs)
    main.main(context)
评论列表
文章目录