def train_net(solver_prototxt, roidb, output_dir, nccl_uid, gpus, rank,
              queue, bbox_means, bbox_stds, pretrained_model=None,
              max_iters=40000):
    """Train a Fast R-CNN network as one rank of a multi-GPU Caffe job.

    Puts Caffe into GPU, multi-process solver mode on device ``gpus[rank]``,
    seeds it with ``cfg.RNG_SEED``, builds a ``SolverWrapper`` and runs its
    ``train_model`` method. Only rank 0 forwards the value returned by
    ``train_model`` to ``queue``; every other rank discards it.

    Args:
        solver_prototxt: solver definition, forwarded to ``SolverWrapper``.
        roidb: training roidb, forwarded to ``SolverWrapper``.
        output_dir: forwarded to ``SolverWrapper`` (presumably the snapshot
            directory -- confirm against SolverWrapper).
        nccl_uid: forwarded to ``SolverWrapper``; presumably the NCCL id
            shared by all ranks -- TODO confirm.
        gpus: sequence of GPU device ids. Its length is used as the Caffe
            solver count, and ``gpus[rank]`` selects this process's device.
        rank: this process's solver rank (also an index into ``gpus``).
        queue: object with a ``put`` method (presumably a
            ``multiprocessing.Queue``). Only rank 0 writes to it.
        bbox_means: forwarded to ``SolverWrapper`` (by name, bbox target
            means -- verify).
        bbox_stds: forwarded to ``SolverWrapper`` (by name, bbox target
            standard deviations -- verify).
        pretrained_model: optional initial weights, forwarded to
            ``SolverWrapper`` as a keyword argument.
        max_iters: iteration limit passed to ``SolverWrapper.train_model``.

    Returns:
        None. Results leave this function only through ``queue`` on rank 0.
    """
    # Configure Caffe for this rank before the solver is built.
    # The call order matches the original setup sequence.
    caffe.set_mode_gpu()
    caffe.set_device(gpus[rank])
    caffe.set_solver_count(len(gpus))
    caffe.set_solver_rank(rank)
    caffe.set_multiprocess(True)
    caffe.set_random_seed(cfg.RNG_SEED)

    wrapper = SolverWrapper(
        solver_prototxt,
        roidb,
        output_dir,
        nccl_uid,
        rank,
        bbox_means,
        bbox_stds,
        pretrained_model=pretrained_model,
    )
    trained_paths = wrapper.train_model(max_iters)

    # Non-root ranks have nothing to report back to the parent process.
    if rank != 0:
        return
    queue.put(trained_paths)
train.py 文件源码
python
阅读 24
收藏 0
点赞 0
评论 0
评论列表
文章目录