model.py source code

python

Project: seglink   Author: bgshih
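import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope)
import ops  # seglink's own ops module (import path assumed)

# N_SEG_CLASSES, N_LNK_CLASSES, N_LOCAL_LINKS, N_CROSS_LINKS and OFFSET_DIM
# are module-level constants defined elsewhere in model.py.
def _detection_classifier(self, maps, ksize, cross_links=False, scope=None):
    """
    Create a SegLink detection classifier on a feature layer.

    Applies three parallel ksize x ksize convolutions (stride 1, 'SAME'
    padding) to `maps` and returns (seg_maps, lnk_maps, reg_maps):
    segment scores, link scores and segment offset regressions.
    """
    # default_name avoids a TypeError when scope is left as None
    with tf.variable_scope(scope, default_name='detection_classifier'):
        # Segment classification: N_SEG_CLASSES scores per location
        seg_depth = N_SEG_CLASSES
        # Link classification: N_LNK_CLASSES scores per link;
        # cross-layer links are added only when cross_links=True
        if cross_links:
            lnk_depth = N_LNK_CLASSES * (N_LOCAL_LINKS + N_CROSS_LINKS)
        else:
            lnk_depth = N_LNK_CLASSES * N_LOCAL_LINKS
        # Segment geometry regression: OFFSET_DIM values per location
        reg_depth = OFFSET_DIM
        # Input channel count from the static shape (NHWC layout)
        map_depth = maps.get_shape()[3].value
        seg_maps = ops.conv2d(maps, map_depth, seg_depth, ksize, 1, 'SAME', scope='conv_cls')
        lnk_maps = ops.conv2d(maps, map_depth, lnk_depth, ksize, 1, 'SAME', scope='conv_lnk')
        reg_maps = ops.conv2d(maps, map_depth, reg_depth, ksize, 1, 'SAME', scope='conv_reg')
    return seg_maps, lnk_maps, reg_maps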
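To make the channel arithmetic concrete, here is a minimal pure-Python sketch of the output sizes of the three heads. It uses neither TensorFlow nor seglink's ops module. The constant values are illustrative assumptions (2 segment classes, 2 link classes, 8 within-layer and 4 cross-layer neighbours, 5 offset values), loosely following the SegLink paper's description rather than read from model.py, and head_depths and head_shapes are hypothetical helpers. It also assumes NHWC layout, which the code implies by reading the channel count from get_shape()[3].

```python
# Illustrative constants (assumed; the real values live in seglink's model.py)
N_SEG_CLASSES = 2   # segment: text vs. non-text
N_LNK_CLASSES = 2   # link: positive vs. negative
N_LOCAL_LINKS = 8   # within-layer neighbours
N_CROSS_LINKS = 4   # cross-layer neighbours
OFFSET_DIM = 5      # segment geometry offsets


def head_depths(cross_links=False):
    """Output channel count of each head, mirroring _detection_classifier."""
    seg_depth = N_SEG_CLASSES
    n_links = N_LOCAL_LINKS + (N_CROSS_LINKS if cross_links else 0)
    lnk_depth = N_LNK_CLASSES * n_links
    reg_depth = OFFSET_DIM
    return seg_depth, lnk_depth, reg_depth


def head_shapes(nhwc_shape, cross_links=False):
    """With stride 1 and 'SAME' padding each head keeps N, H and W;
    only the channel dimension changes."""
    n, h, w, _ = nhwc_shape
    return tuple((n, h, w, d) for d in head_depths(cross_links))


if __name__ == "__main__":
    print(head_depths(cross_links=False))  # (2, 16, 5)
    print(head_depths(cross_links=True))   # (2, 24, 5)
    print(head_shapes((1, 32, 32, 512), cross_links=True))
    # ((1, 32, 32, 2), (1, 32, 32, 24), (1, 32, 32, 5))
```

Because every head uses stride 1 and 'SAME' padding, the three outputs share the input's spatial grid, so each location carries its segment scores, link scores and offsets at aligned positions. Only the link head's depth depends on cross_links; with these assumed values and cross_links=True the heads together emit 31 channels per location.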