def _plot_segmentation(data, target, output):
    """Show input, ground truth and predicted segmentation for the first sample of a batch.

    Args:
        data: Input batch; ``data[0, 0]`` (first sample, first channel) is shown.
        target: Target batch; ``argmax`` over axis 0 of ``target[0]`` is shown as
            the ground-truth label map.
        output: Contents of the net's ``argmax`` blob; ``output[0, 0]`` is shown.
    """
    fig, sub = plt.subplots(ncols=3, figsize=(15, 5))
    sub[0].set_title('Input')
    sub[0].imshow(data[0, 0, :, :])
    sub[1].set_title('Ground Truth')
    # Collapse the channel axis of the first sample into a label map
    # (presumably one channel per class from get_data -- TODO confirm).
    sub[1].imshow(np.argmax(target[0, :, :, :], axis=0))
    sub[2].set_title('Segmentation')
    sub[2].imshow(output[0, 0, :, :])
    plt.show()
    # In non-interactive mode show() either blocked until the window was closed
    # or (e.g. Agg backend) did nothing; closing here keeps figures from piling
    # up across iterations. In interactive mode the window is left open.
    if not plt.isinteractive():
        plt.close(fig)


def train_network(solver_file, num_classes, batch_size, num_iterations, use_gpu=True):
    """Train a Caffe net from in-memory batches, plotting a sample after every step.

    Batches come from ``get_data`` and are written straight into the solver
    net's ``data`` and ``target`` blobs (no database input layer). The solver is
    advanced one iteration at a time. After each iteration, a figure shows the
    first sample's input, ground truth and the net's ``argmax`` output.

    Args:
        solver_file: Solver definition passed to ``caffe.get_solver``.
        num_classes: Forwarded to ``get_data`` as ``numclasses``.
        batch_size: New length of the first axis of the ``data`` and ``target``
            blobs, and the batch size requested from ``get_data``.
        num_iterations: Number of ``solver.step(1)`` calls to make.
        use_gpu: Select Caffe GPU mode if True, CPU mode otherwise.
    """
    if use_gpu:
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()

    solver = caffe.get_solver(solver_file)
    net = solver.net
    blobs = net.blobs

    # Only the batch axis is changed; all other axes keep the sizes declared
    # in the net definition. Unpacking shape[1:] works for blobs of any rank.
    for blob_name in ('data', 'target'):
        blobs[blob_name].reshape(batch_size, *blobs[blob_name].shape[1:])
    net.reshape()

    for _ in range(num_iterations):
        data, target = get_data(batch_size, numclasses=num_classes)
        blobs['data'].data[...] = data
        blobs['target'].data[...] = target
        solver.step(1)
        # View of the blob's buffer: used immediately, before the next step
        # overwrites it.
        output = blobs['argmax'].data[...]
        _plot_segmentation(data, target, output)
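# Scraped web-page residue (not code), kept as comments so the module parses:
# unet_segmentation_no_db_example.py 文件源码 (file source)
# python
# 阅读 49 (reads: 49)
# 收藏 0 (bookmarks: 0)
# 点赞 0 (likes: 0)
# 评论 0 (comments: 0)
# 评论列表 (comment list)
# 文章目录 (article table of contents)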