def sample_hard_negatives(img, roi_mask, out_dir, img_id, abn,
                          patch_size=256, neg_cutoff=.35, nb_bkg=100,
                          start_sample_nb=0,
                          bkg_dir='background', verbose=False):
    '''Sample background ("hard negative") patches near an ROI and save them.

    Candidate patch centres are drawn from the bounding box of the largest
    contour in ``roi_mask``, grown by half a patch on every side and passed
    through ``crop_val`` against the padded image size. A candidate is kept
    only when ``overlap_patch_roi`` returns False for it at ``neg_cutoff``.
    Each kept patch is written as ``<img_id>_<abn>_<NNNN>.png`` into
    ``os.path.join(out_dir, bkg_dir)``, numbered from ``start_sample_nb`` up
    to ``start_sample_nb + nb_bkg - 1``. Sampling uses a fixed seed (12345),
    so identical inputs yield identical patch locations.

    WARNING: the definition of hns may be problematic. Studies have shown
    that the context of an ROI is also useful for classification.

    Args:
        img: image array to cut patches from (indexed as [row, col]; presumably
            2-D grayscale since patches are saved with PIL mode 'I' -- confirm).
        roi_mask: ROI mask sharing ``img``'s coordinate frame; nonzero pixels
            (after a cast to uint8) are treated as ROI.
        out_dir: parent output directory.
        img_id: image identifier (str), first part of each file name.
        abn: abnormality identifier, second part of each file name.
        patch_size: side length of the square patches, in pixels.
        neg_cutoff: overlap cutoff forwarded to ``overlap_patch_roi``.
        nb_bkg: number of patches to save.
        start_sample_nb: first index used in the file names.
        bkg_dir: sub-directory of ``out_dir`` receiving the patches; this
            function does not create it.
        verbose: print the ROI centroid and each sampled centre.

    Raises:
        ValueError: if ``roi_mask`` contains no contour.
    '''
    # Integer half-patch: a valid slice bound on both Python 2 and Python 3
    # (``patch_size/2`` is a float on Python 3 and breaks slicing/randint).
    half = patch_size // 2
    bkg_out = os.path.join(out_dir, bkg_dir)
    basename = '_'.join([img_id, str(abn)])
    # Both arrays get the same margin so their coordinates stay aligned;
    # everything below works in the padded frame.
    img = add_img_margins(img, half)
    roi_mask = add_img_margins(roi_mask, half)

    # Get the bounding box of the largest ROI contour.
    roi_mask_8u = roi_mask.astype('uint8')
    # findContours returns (contours, hierarchy) on OpenCV 2.x and 4.x but
    # (image, contours, hierarchy) on 3.x; the contour list is always the
    # second-to-last element. The copy guards against old versions that
    # modify their input.
    contours = cv2.findContours(
        roi_mask_8u.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        raise ValueError(
            "ROI mask for %s has no contour; cannot sample hard negatives"
            % basename)
    cont_areas = [cv2.contourArea(cont) for cont in contours]
    largest = contours[int(np.argmax(cont_areas))]
    rx, ry, rw, rh = cv2.boundingRect(largest)
    if verbose:
        M = cv2.moments(largest)
        if M['m00'] != 0:
            cx = int(M['m10'] / M['m00'])
            cy = int(M['m01'] / M['m00'])
            print("ROI centroid= (%d, %d)" % (cx, cy))
        else:
            # Zero-area contour (e.g. a 1-pixel line): centroid is undefined.
            print("ROI centroid= undefined (zero-area contour)")
        sys.stdout.flush()

    # Candidate-centre ranges: ROI bounding box grown by half a patch per
    # side, then passed through crop_val (presumably a clamp into
    # [half, size - half] so a full patch fits -- confirm). These do not
    # change between draws, so compute them once.
    x_lo = crop_val(rx - half, half, img.shape[1] - half)
    x_hi = crop_val(rx + rw + half, half, img.shape[1] - half)
    y_lo = crop_val(ry - half, half, img.shape[0] - half)
    y_hi = crop_val(ry + rh + half, half, img.shape[0] - half)

    rng = np.random.RandomState(12345)  # fixed seed -> reproducible sampling
    # Sample hard negative patches.
    # NOTE(review): rejected draws are not capped; if every candidate centre
    # overlaps the ROI beyond neg_cutoff, this loop never terminates.
    sampled_bkg = start_sample_nb
    while sampled_bkg < start_sample_nb + nb_bkg:
        x = rng.randint(x_lo, x_hi)  # upper bound is exclusive
        y = rng.randint(y_lo, y_hi)
        if overlap_patch_roi((x, y), patch_size, roi_mask, cutoff=neg_cutoff):
            continue  # too much ROI in this patch; redraw
        patch = img[y - half:y + half, x - half:x + half]
        patch = patch.astype('int32')
        # NOTE(review): a constant patch gives high == low here; confirm
        # toimage does not divide by (max - min) for mode 'I'.
        patch_img = toimage(patch, high=patch.max(), low=patch.min(),
                            mode='I')
        filename = basename + "_%04d" % (sampled_bkg) + ".png"
        patch_img.save(os.path.join(bkg_out, filename))
        sampled_bkg += 1
        if verbose:
            print("sampled a hns patch at (x,y) center= (%d, %d)" % (x, y))
            sys.stdout.flush()
评论列表
文章目录