MSDN.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:MSDN 作者: yikang-li 项目源码 文件源码
def proposal_target_layer(object_rois, region_rois, gt_objects, gt_relationships,
                          gt_regions, n_classes_obj, voc_sign, is_training=False,
                          graph_generation=False):
    """Compute object / phrase / region targets and wrap them as CUDA variables.

    Thin wrapper around ``proposal_target_layer_py``. The incoming RoI tensors
    are copied to host NumPy arrays, and the NumPy implementation does the work.
    Its outputs are then converted with
    ``network.np_to_variable(..., is_cuda=True)``.

    Parameters
    ----------
    object_rois : tensor/Variable exposing ``.data.cpu().numpy()``
        Object proposals, rows of ``[0, x1, y1, x2, y2]``.
    region_rois : tensor/Variable exposing ``.data.cpu().numpy()``
        Region proposals, rows of ``[0, x1, y1, x2, y2]``.
    gt_objects : array, (G_obj, 5)
        ``[x1, y1, x2, y2, obj_class]`` per ground-truth object (int).
    gt_relationships : array, (G_obj, G_obj)
        Predicate class for each object pair (int), -1 for no relationship.
    gt_regions : array, (G_region, 4 + 40)
        ``[x1, y1, x2, y2, word_index...]`` per region, -1 used as padding.
    n_classes_obj : int
        Number of object classes, passed through to the NumPy layer.
    voc_sign :
        Passed through unchanged to ``proposal_target_layer_py``; its meaning
        is defined there.
    is_training : bool
        Passed through to the NumPy layer. When True, the label, target and
        weight arrays are also converted to CUDA variables.
    graph_generation : bool
        Passed through (as a keyword argument) to the NumPy layer.

    Returns
    -------
    tuple of six items
        ``(object_rois, object_labels, bbox_targets, bbox_inside_weights,
        bbox_outside_weights)``,
        ``(phrase_rois, phrase_label)``,
        ``(region_rois, region_seq, bbox_targets_region,
        bbox_inside_weights_region, bbox_outside_weights_region)``,
        ``mat_object``, ``mat_phrase``, ``mat_region``.

        The three RoI arrays are always converted. The labels, targets,
        weights and ``region_seq`` are converted only when ``is_training`` is
        True; otherwise they are returned exactly as
        ``proposal_target_layer_py`` produced them. The ``mat_*`` values are
        never converted.
    """

    def to_cuda(array, dtype=None):
        # Forward ``dtype`` only when requested so np_to_variable keeps its
        # own default dtype for the float-valued arrays.
        if dtype is None:
            return network.np_to_variable(array, is_cuda=True)
        return network.np_to_variable(array, is_cuda=True, dtype=dtype)

    # The NumPy layer operates on host arrays.
    object_rois_np = object_rois.data.cpu().numpy()
    region_rois_np = region_rois.data.cpu().numpy()

    (object_labels, object_rois_np, bbox_targets, bbox_inside_weights,
     bbox_outside_weights, mat_object,
     phrase_label, phrase_rois_np, mat_phrase,
     region_seq, region_rois_np,
     bbox_targets_region, bbox_inside_weights_region,
     bbox_outside_weights_region, mat_region) = proposal_target_layer_py(
        object_rois_np, region_rois_np, gt_objects, gt_relationships,
        gt_regions, n_classes_obj, voc_sign, is_training,
        graph_generation=graph_generation)

    if is_training:
        # Label / sequence arrays become LongTensor; targets and weights use
        # np_to_variable's default dtype.
        object_labels = to_cuda(object_labels, torch.LongTensor)
        bbox_targets, bbox_inside_weights, bbox_outside_weights = (
            to_cuda(arr)
            for arr in (bbox_targets, bbox_inside_weights, bbox_outside_weights))
        phrase_label = to_cuda(phrase_label, torch.LongTensor)
        region_seq = to_cuda(region_seq, torch.LongTensor)
        bbox_targets_region, bbox_inside_weights_region, bbox_outside_weights_region = (
            to_cuda(arr)
            for arr in (bbox_targets_region, bbox_inside_weights_region,
                        bbox_outside_weights_region))

    # RoIs are needed downstream in both training and inference.
    object_rois_var = to_cuda(object_rois_np)
    phrase_rois_var = to_cuda(phrase_rois_np)
    region_rois_var = to_cuda(region_rois_np)

    object_outputs = (object_rois_var, object_labels, bbox_targets,
                      bbox_inside_weights, bbox_outside_weights)
    phrase_outputs = (phrase_rois_var, phrase_label)
    region_outputs = (region_rois_var, region_seq, bbox_targets_region,
                      bbox_inside_weights_region, bbox_outside_weights_region)
    return object_outputs, phrase_outputs, region_outputs, mat_object, mat_phrase, mat_region
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号