def knn_masked_data(trX, trY, missing_data_dir, input_shape, k):
    """Classify partially-missing images with a k-NN fitted per mask.

    The images listed in ``index.txt`` and their masks, listed line-by-line in
    ``index_mask.txt``, are read from ``join(script_dir, missing_data_dir)``.
    For every image, the training set is corrupted with that image's mask and
    a brute-force k-nearest-neighbours classifier is fitted on the corrupted
    data. This way, training and test vectors agree on the unobserved
    positions.

    Args:
        trX: training samples, one flattened image per row (assumed shape
            ``(n_train, prod(input_shape))`` -- TODO confirm against callers).
        trY: numeric class labels for ``trX``; ``int(max(trY)) + 1`` output
            columns are allocated.
        missing_data_dir: directory (relative to ``script_dir``) holding the
            two index files and the image / mask files they name.
        input_shape: shape passed to ``load_image``; each loaded image is
            flattened to ``np.prod(input_shape)`` values.
        k: number of neighbours used by the classifier.

    Returns:
        ``np.ndarray`` of shape ``(n_images, int(max(trY)) + 1)``. Row ``i``
        holds the predicted class probabilities for line ``i`` of
        ``index.txt``; column ``c`` is the probability of label ``c``.
    """
    data_dir = join(script_dir, missing_data_dir)
    # ndmin=2 keeps a one-line index file 2-D; otherwise row[0] would be the
    # first *character* of the file name instead of the file name itself.
    raw_im_data = np.loadtxt(join(data_dir, 'index.txt'), delimiter=' ', dtype=str, ndmin=2)
    raw_mask_data = np.loadtxt(join(data_dir, 'index_mask.txt'), delimiter=' ', dtype=str, ndmin=2)

    # Using 'brute' method since we only want to do one query per classifier
    # so this will be quicker as it avoids overhead of creating a search tree
    knn_m = KNeighborsClassifier(algorithm='brute', n_neighbors=k)

    total_images = raw_im_data.shape[0]
    n_pixels = int(np.prod(input_shape))
    prob_Y_hat = np.zeros((total_images, int(np.max(trY) + 1)))

    pbar = progressbar.ProgressBar(widgets=[progressbar.FormatLabel('\rProcessed %(value)d of %(max)d Images '), progressbar.Bar()], maxval=total_images, term_width=50).start()
    for i in range(total_images):
        mask_im = load_image(join(data_dir, raw_mask_data[i][0]), input_shape, 1).reshape(n_pixels)
        # The mask file is 1 at missing locations, so `mask` is True where the
        # pixel is observed.
        mask = np.logical_not(mask_im > eps)
        v_im = load_image(join(data_dir, raw_im_data[i][0]), input_shape, 255).reshape(n_pixels)
        # Zero the missing pixels of the query as well, so that it is compared
        # on exactly the same footing as the masked training set. This is a
        # no-op if the stored image already has zeros there.
        v_im = v_im * mask

        # Corrupt the whole training set according to the current mask;
        # broadcasting the (n_pixels,) mask over the rows avoids building an
        # (n_train, n_pixels) tiled copy on every iteration.
        corr_trX = np.multiply(trX, mask)
        knn_m.fit(corr_trX, trY)

        # predict_proba's columns follow knn_m.classes_ (the sorted labels
        # present in trY), which need not be every integer in 0..max(trY);
        # scatter them into the matching label columns.
        class_cols = knn_m.classes_.astype(int)
        prob_Y_hat[i, class_cols] = knn_m.predict_proba(v_im.reshape(1, -1))[0]
        pbar.update(i + 1)
    pbar.finish()
    return prob_Y_hat
评论列表
文章目录