def _detection_classifier(self, maps, ksize, cross_links=False, scope=None):
    """Create a SegLink detection classifier on a feature layer.

    Three parallel convolutions are applied to ``maps``: one for segment
    classification (``conv_cls``), one for link classification
    (``conv_lnk``) and one for offset regression (``conv_reg``).

    Args:
        maps: Input feature tensor. Its static dimension at index 3 is read
            as the input depth (presumably NHWC channels -- TODO confirm),
            so the tensor must have rank >= 4 with that dimension known.
        ksize: Kernel size forwarded unchanged to ``ops.conv2d``.
        cross_links: If True, the link branch predicts
            ``N_LOCAL_LINKS + N_CROSS_LINKS`` links per location instead of
            only ``N_LOCAL_LINKS``.
        scope: Name of the variable scope to open. If None, a uniquified
            default scope named ``detection_classifier`` is used.

    Returns:
        A tuple ``(seg_maps, lnk_maps, reg_maps)`` from the ``conv_cls``,
        ``conv_lnk`` and ``conv_reg`` convolutions. Their requested depths
        are ``N_SEG_CLASSES``, ``N_LNK_CLASSES * n_links`` and
        ``OFFSET_DIM``. This assumes the third ``ops.conv2d`` argument is
        the output depth -- TODO confirm.
    """
    # tf.variable_scope(None) without default_name raises TypeError, which
    # made the default scope=None unusable; default_name fixes that while
    # leaving explicit scope names unaffected.
    with tf.variable_scope(scope, default_name='detection_classifier'):
        seg_depth = N_SEG_CLASSES
        # N_CROSS_LINKS is only consulted when cross-layer links are requested.
        n_links = N_LOCAL_LINKS + (N_CROSS_LINKS if cross_links else 0)
        lnk_depth = N_LNK_CLASSES * n_links
        reg_depth = OFFSET_DIM
        map_depth = maps.get_shape()[3].value
        # Positional args after the depths: ksize, 1 (presumably stride) and
        # 'SAME' padding -- TODO confirm against ops.conv2d's signature.
        seg_maps = ops.conv2d(maps, map_depth, seg_depth, ksize, 1, 'SAME', scope='conv_cls')
        lnk_maps = ops.conv2d(maps, map_depth, lnk_depth, ksize, 1, 'SAME', scope='conv_lnk')
        reg_maps = ops.conv2d(maps, map_depth, reg_depth, ksize, 1, 'SAME', scope='conv_reg')
    return seg_maps, lnk_maps, reg_maps
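# 评论列表 ("Comment list") -- web-page navigation residue, not code
# 文章目录 ("Table of contents") -- web-page navigation residue, not code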