def label2sm_fm(self, label):
    """Rasterise keypoint labels into per-point square masks on the feature map.

    Each keypoint ``ii`` of each batch element is painted as a filled square
    with value ``ii + 1`` on a ``(self.fm_height, self.fm_width)`` grid.
    Cells not covered by any square stay 0. Where squares overlap, later
    points overwrite earlier ones.

    Args:
        label: array indexed as ``label[batch_idx][2 * ii: 2 * ii + 2]``,
            giving the ``(w, h)`` pair for point ``ii``. ``label.shape[0]``
            is used as the batch size. The coordinates are multiplied by the
            feature-map size, so they are presumably normalised to [0, 1]
            -- TODO confirm with the caller.

    Reads ``self.fm_width``, ``self.fm_height``, ``self.radius`` and
    ``self.points_num``.

    Returns:
        ``np.int32`` array of shape ``(label.shape[0], self.fm_height,
        self.fm_width)``.
    """
    fm_h, fm_w, radius = self.fm_height, self.fm_width, self.radius

    def get_point(a_list, idx):
        # Scale the (w, h) pair of point ``idx`` to integer feature-map
        # coordinates. int() truncates toward zero.
        w, h = a_list[idx * 2: idx * 2 + 2]
        return int(w * fm_w), int(h * fm_h)

    def draw(img, center, idx):
        # Paint every cell whose Chebyshev distance to ``center`` is strictly
        # less than ``radius``, clipped to the image. On each axis that is
        # the half-open index range [c - radius + 1, c + radius).
        # NOTE(review): because of the strict "<", radius r paints a square
        # of side 2*r - 1 (radius 1 paints only the centre; radius <= 0
        # paints nothing). Kept for compatibility -- confirm "<=" was not
        # the original intent.
        w0, h0 = center
        height, width = img.shape
        h_lo, h_hi = max(0, h0 - radius + 1), min(height, h0 + radius)
        w_lo, w_hi = max(0, w0 - radius + 1), min(width, w0 + radius)
        # Explicit guard: for a point far outside the map, a negative upper
        # bound would wrap around in a NumPy slice instead of selecting
        # nothing.
        if h_lo < h_hi and w_lo < w_hi:
            img[h_lo:h_hi, w_lo:w_hi] = idx + 1

    # Allocate as int32 directly, so no float64 buffer and final copy are needed.
    fm_label = np.zeros((label.shape[0], fm_h, fm_w), dtype=np.int32)
    # range (not xrange) keeps this runnable on both Python 2 and 3.
    for batch_idx, img in enumerate(fm_label):
        for ii in range(self.points_num):
            draw(img, get_point(label[batch_idx], ii), ii)
    return fm_label
read_data.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录