MSDN.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:MSDN 作者: yikang-li 项目源码 文件源码
def __init__(self, nhidden, n_object_cats, n_predicate_cats, n_vocab, voc_sign,
             max_word_length, MPS_iter, use_language_loss, object_loss_weight,
             predicate_loss_weight,
             dropout=False,
             use_kmeans_anchors=False,
             gate_width=128,
             nhidden_caption=256,
             nembedding=256,
             rnn_type='LSTM_normal',
             rnn_droptout=0.0, rnn_bias=False,
             use_region_reg=False, use_kernel=False):
    """Assemble the sub-modules of the hierarchical descriptive model.

    Most arguments are handed positionally to the base-class initializer;
    the attributes read back below (``self.n_classes_obj``,
    ``self.n_classes_pred``, ``self.use_region_reg``, ``self.rnn_type``,
    ``self.n_vocab``, ``self.nhidden``, ``self.nhidden_caption``,
    ``self.nembedding``, ``self.max_word_length``, ``self.voc_sign``) are
    presumably stored there -- confirm against the base class.

    Arguments consumed directly by this initializer:
        nhidden: output width of every fc6/fc7 layer and input width of
            the score / bbox / objectiveness heads.
        MPS_iter: when equal to 0 no message passing structure is built
            and ``self.mps`` is ``None``.
        use_language_loss: selects between a fully configured
            ``Language_Model`` and a minimal placeholder instance.
        use_kmeans_anchors: forwarded to ``RPN``.
        dropout, gate_width, use_kernel: forwarded to
            ``Hierarchical_Message_Passing_Structure``.
        rnn_bias, rnn_droptout: forwarded to ``Language_Model`` as
            ``bias`` and ``dropout`` (the parameter spelling is kept as-is
            because callers may pass it by keyword).

    NOTE: the order in which sub-modules are created and initialised is
    kept identical to the original on purpose; it fixes the registration
    order of the sub-modules and, presumably, the order in which random
    initial weights are drawn.
    """
    super(Hierarchical_Descriptive_Model, self).__init__(
        nhidden, n_object_cats, n_predicate_cats, n_vocab, voc_sign,
        max_word_length, MPS_iter, use_language_loss,
        object_loss_weight, predicate_loss_weight,
        dropout, use_kmeans_anchors, nhidden_caption, nembedding,
        rnn_type, use_region_reg)

    # Geometry shared by the three RoI branches.
    pool_size = 7
    spatial_scale = 1.0 / 16
    pooled_dim = 512 * pool_size * pool_size

    self.rpn = RPN(use_kmeans_anchors)

    # One RoI pooling layer per branch: object, phrase and region.
    self.roi_pool_object = RoIPool(pool_size, pool_size, spatial_scale)
    self.roi_pool_phrase = RoIPool(pool_size, pool_size, spatial_scale)
    self.roi_pool_region = RoIPool(pool_size, pool_size, spatial_scale)

    # Two fully connected layers per branch (fc6 with ReLU, fc7 without).
    self.fc6_obj = FC(pooled_dim, nhidden, relu=True)
    self.fc7_obj = FC(nhidden, nhidden, relu=False)
    self.fc6_phrase = FC(pooled_dim, nhidden, relu=True)
    self.fc7_phrase = FC(nhidden, nhidden, relu=False)
    self.fc6_region = FC(pooled_dim, nhidden, relu=True)
    self.fc7_region = FC(nhidden, nhidden, relu=False)

    # The hierarchical message passing structure is optional.
    if MPS_iter != 0:
        self.mps = Hierarchical_Message_Passing_Structure(
            nhidden, dropout,
            gate_width=gate_width, use_kernel_function=use_kernel)
        network.weights_normal_init(self.mps, 0.01)
    else:
        self.mps = None

    # Prediction heads.
    self.score_obj = FC(nhidden, self.n_classes_obj, relu=False)
    self.bbox_obj = FC(nhidden, self.n_classes_obj * 4, relu=False)
    self.score_pred = FC(nhidden, self.n_classes_pred, relu=False)

    # Region box regression head exists only when region regression is on.
    if not self.use_region_reg:
        self.bbox_region = None
    else:
        self.bbox_region = FC(nhidden, 4, relu=False)
        network.weights_normal_init(self.bbox_region, 0.01)

    self.objectiveness = FC(nhidden, 2, relu=False)

    # Keyword arguments common to both language-model configurations.
    lm_kwargs = dict(rnn_type=self.rnn_type, ntoken=self.n_vocab,
                     voc_sign=self.voc_sign)
    if use_language_loss:
        lm_kwargs.update(nimg=self.nhidden, nhidden=self.nhidden_caption,
                         nembed=self.nembedding, nlayers=2,
                         nseq=self.max_word_length,
                         bias=rnn_bias, dropout=rnn_droptout)
    else:
        # Minimal placeholder model, present just to make the program run.
        lm_kwargs.update(nimg=1, nhidden=1, nembed=1, nlayers=1, nseq=1)
    self.caption_prediction = Language_Model(**lm_kwargs)

    # Normal initialisation of the heads, each with its own std-dev argument.
    for head, std in ((self.score_obj, 0.01),
                      (self.bbox_obj, 0.005),
                      (self.score_pred, 0.01),
                      (self.objectiveness, 0.01)):
        network.weights_normal_init(head, std)

    self.objectiveness_loss = None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号