unet_segmentation_no_db_example.py source code

python

Project: peters-stuff  Author: peterneher
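import caffe
import matplotlib.pyplot as plt
import numpy as np

# get_data() is defined elsewhere in unet_segmentation_no_db_example.py and is not shown here.


def train_network(solver_file, num_classes, batch_size, num_iterations, use_gpu=True):
    if num_iterations < 1:
        # data, target and output are only defined once the loop has run
        raise ValueError('num_iterations must be at least 1')

    # Select the Caffe compute device
    if use_gpu:
        caffe.set_mode_gpu()
    else:
        caffe.set_mode_cpu()

    solver = caffe.get_solver(solver_file)

    # Resize the input blobs to the requested batch size, keeping the
    # channel, height and width dimensions defined in the net prototxt
    data_blob = solver.net.blobs['data']
    target_blob = solver.net.blobs['target']
    data_blob.reshape(batch_size, data_blob.shape[1], data_blob.shape[2], data_blob.shape[3])
    target_blob.reshape(batch_size, target_blob.shape[1], target_blob.shape[2], target_blob.shape[3])
    solver.net.reshape()

    for _ in range(num_iterations):
        # Draw a new batch in memory and copy it into the input blobs
        data, target = get_data(batch_size, numclasses=num_classes)
        solver.net.blobs['data'].data[...] = data
        solver.net.blobs['target'].data[...] = target

        # One forward/backward pass plus a parameter update
        solver.step(1)

    # Predicted label map produced by the 'argmax' layer for the last batch
    output = solver.net.blobs['argmax'].data[...]

    # Compare input, ground truth and prediction for the first sample
    fig, sub = plt.subplots(ncols=3, figsize=(15, 5))
    sub[0].set_title('Input')
    sub[0].imshow(data[0, 0, :, :])
    sub[1].set_title('Ground Truth')
    # target is one-hot along the class axis; argmax recovers the label map
    sub[1].imshow(np.argmax(target[0, :, :, :], axis=0))
    sub[2].set_title('Segmentation')
    sub[2].imshow(output[0, 0, :, :])
    plt.show()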
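train_network() relies on get_data(), which lives elsewhere in the original file and is not part of this listing. If you want to exercise the training loop without that file, a hypothetical stand-in might look like the sketch below. It assumes only what the plotting code implies: data[n, 0] is a single-channel image, and target is one-hot along axis 1 (the ground-truth panel decodes it with np.argmax over the class axis). The height, width and seed parameters and the rectangle-based synthetic images are illustrative assumptions, not the author's implementation.

```python
import numpy as np


def get_data(batch_size, numclasses=2, height=64, width=64, seed=None):
    """Hypothetical stand-in for the example's get_data() helper.

    Returns (data, target) in NCHW layout:
      data:   float32, shape (batch_size, 1, height, width)
      target: float32 one-hot, shape (batch_size, numclasses, height, width)
    Assumes height and width are at least 8.
    """
    rng = np.random.default_rng(seed)
    labels = np.zeros((batch_size, height, width), dtype=np.int64)
    for n in range(batch_size):
        # Paint one random axis-aligned rectangle per foreground class;
        # class 0 is the background
        for c in range(1, numclasses):
            h = rng.integers(height // 8, height // 2)
            w = rng.integers(width // 8, width // 2)
            y = rng.integers(0, height - h)
            x = rng.integers(0, width - w)
            labels[n, y:y + h, x:x + w] = c
    # Input image: class index as intensity plus Gaussian noise
    data = labels[:, np.newaxis, :, :].astype(np.float32)
    data += rng.normal(0.0, 0.3, size=data.shape).astype(np.float32)
    # One-hot encode along a new class axis: (N, H, W) -> (N, C, H, W)
    classes = np.arange(numclasses)[np.newaxis, :, np.newaxis, np.newaxis]
    target = (classes == labels[:, np.newaxis, :, :]).astype(np.float32)
    return data, target
```

Because train_network() passes only batch_size and numclasses, the default height and width must match the 'data' and 'target' blob shapes declared in your net prototxt. With such a helper defined, a call like train_network('solver.prototxt', num_classes=2, batch_size=4, num_iterations=1000) would run 1000 solver steps and plot the final batch; 'solver.prototxt' is a placeholder path, and the net must provide the 'data', 'target' and 'argmax' blobs the listing expects.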