predict.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:minc_keras 作者: tfunck 项目源码 文件源码
def save_image(X_test, X_predict, Y_test ,output_fn, slices=None, nslices=25 ):
    '''
        Writes X_test, X_predict, and Y_test to a single png image. Unless specific
        slices are given, function will write <nslices> evenly spaced slices.

        Each panel of the output shows, side by side (concatenated along axis 1),
        the normalized X_test, X_predict and Y_test arrays for one slice index.
        Panels are laid out on a square grid of int(sqrt(n)) x int(sqrt(n)) cells,
        where n = min(len(slices), nslices); slices beyond the grid capacity are
        not drawn.

        args:
        X_test -- array of values input to model, indexed by slice along axis 0
        X_predict -- array of predicted values based on X_test, same indexing
        Y_test -- array to compare the prediction against, same indexing
                  (presumably the ground-truth target -- confirm against caller)
        output_fn -- filename of output png file
        slices -- axial slice indices to save to png file, None by default
        nslices -- number of evenly spaced slices to save to png

        returns: 0

        raises: ValueError if nslices < 1 or if there is no slice to plot
    '''
    if nslices < 1:
        raise ValueError("nslices must be at least 1, got %r" % (nslices,))

    #if no slices are defined by user, set slices to evenly sampled slices along entire number of slices in 3d image volume
    #NOTE: use "is None" rather than "== None" so that a numpy array of indices can be passed as <slices>
    if slices is None:
        nvol = X_test.shape[0]
        #clamp the step to 1 so that volumes with fewer than <nslices> slices do not produce range(..., step=0)
        step = max(1, int(nvol / nslices))
        slices = range(0, nvol, step)
    slices = list(slices)

    #never plot more panels than there are slices available
    npanels = min(len(slices), nslices)
    if npanels < 1:
        raise ValueError("save_image needs at least one slice to plot")

    #set number of rows and columns in output image. currently, sqrt() means that the image will be a square, but this could be changed if a more vertical orientation is prefered
    ncol = int(np.sqrt(npanels))
    nrow = ncol

    #use a fresh figure (instead of the shared figure number 1) so that repeated calls do not draw on top of each other
    fig = plt.figure()
    try:
        #using gridspec because it seems to give a bit more control over the spacing of the images. define a nrow x ncol grid
        outer_grid = gridspec.GridSpec(nrow, ncol, wspace=0.0, hspace=0.0)

        #cells are filled in grid index order, one slice per cell
        for i in range(nrow * ncol):
            s = slices[i]

            #normalize the three input numpy arrays. normalizing them independently is necessary so that they all have the same scale
            #(normalize is defined elsewhere in this file)
            A = normalize(X_test[s])
            B = normalize(X_predict[s])
            C = normalize(Y_test[s])
            ABC = np.concatenate([A, B, C], axis=1)

            #use imshow to display all three images
            ax = fig.add_subplot(outer_grid[i])
            ax.imshow(ABC)
            ax.axis('off')

        #spacing is the same for every panel, so it only needs to be set once
        fig.subplots_adjust(hspace=0.0, wspace=0.0)
        outer_grid.tight_layout(fig, pad=0, h_pad=0, w_pad=0)
        fig.savefig(output_fn, dpi=750)
    finally:
        #release the figure even if plotting or saving fails, otherwise figures pile up in memory
        plt.close(fig)
    return 0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号