tracker.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:hart 作者: akosiorek 项目源码 文件源码
def _build(self):
        """Assemble the tracker's TF graph: run an attention RNN over the
        input sequence and produce attention regions, presence logits and
        predicted bounding boxes.

        Reads many attributes set elsewhere on the instance (e.g.
        `feature_extractor`, `rnn_units`, `inpt`, `bbox0`, `presence0`,
        debug/readout flags) and writes graph tensors back onto `self`
        (`glimpse`, `rnn_output`, `presence`, `attention`, `pred_bbox`, ...).
        """
        # Recurrent attention cell; its output arity depends on the
        # `debug` and `dfn_readout` flags (see unpacking of raw_outputs below).
        self.cell = AttentionCell(self.feature_extractor,
                                  self.rnn_units, self.att_gain, self.glimpse_size, self.inpt_size,
                                  self.batch_size, self.zoneout_prob,
                                  self.attention_module, self.normalize_glimpse, self.identity_init,
                                  self.debug, self.dfn_readout, self.feature_shape, is_training=self.is_training)

        # Initial state is seeded from the first frame's ground-truth bbox,
        # presence and image (`self.inpt[0]` — time-major input, see below).
        first_state = self.cell.zero_state(self.batch_size, tf.float32, self.bbox0, self.presence0, self.inpt[0],
                                           self.transform_init_features, self.transform_init_state)

        # time_major=True: `self.inpt` and all per-step outputs have shape
        # (time, batch, ...).
        raw_outputs, state = tf.nn.dynamic_rnn(self.cell, self.inpt,
                                                        initial_state=first_state,
                                                        time_major=True,
                                                        scope=tf.get_variable_scope())

        if self.debug:
            # In debug mode the cell additionally emits the extracted glimpse.
            (outputs, attention, presence, glimpse) = raw_outputs[:4]
            shape = (-1, self.batch_size, 1) + tuple(self.glimpse_size)
            self.glimpse = tf.reshape(glimpse, shape, 'glimpse_shape')
            tf.summary.histogram('rnn_outputs', outputs)
        else:
            (outputs, attention, presence) = raw_outputs[:3]

        if self.dfn_readout:
            # Last three raw outputs are the DFN readout: object-mask logits,
            # mask features and a weight-decay term.
            self.obj_mask_logit = tf.reshape(raw_outputs[-3], (-1, self.batch_size, 1) + tuple(self.feature_shape))
            self.obj_mask = tf.nn.sigmoid(self.obj_mask_logit)
            # [1:] drops the first timestep to align with `states_flat` below;
            # 10 is presumably the DFN feature width — TODO confirm against AttentionCell.
            obj_mask_features_flat = tf.reshape(raw_outputs[-2][1:], (-1, 10))
            self.dfn_weight_decay = raw_outputs[-1]

        self.rnn_output = outputs
        self.hidden_state = state[-1]
        self.raw_presence = presence
        # Presence is emitted as a logit; squash to a probability.
        self.presence = tf.nn.sigmoid(self.raw_presence)

        # Drop timestep 0 (its bbox is given, not predicted) and flatten
        # (time, batch) into one axis for the readout MLP.
        states_flat = tf.reshape(outputs[1:], (-1, self.rnn_units), 'flatten_states')
        if self.dfn_readout:
            states_flat = tf.concat(axis=1, values=(states_flat, obj_mask_features_flat))

        # MLP mapping hidden state -> 4 bbox-correction values per step.
        hidden_to_bbox = MLP(states_flat, self.rnn_units, 4, transfer=tf.nn.tanh, name='fc_h2bbox',
                             weight_init=self.cell._rec_init, bias_init=tf.constant_initializer())

        if self.debug:
            # NOTE(review): passes the MLP object itself, not `.output`, to the
            # summary — relies on MLP being tensor-convertible; verify.
            tf.summary.histogram('bbox_diff', hidden_to_bbox)

        # Shift attention one step forward in time: the attention used AT step t
        # was produced at t-1; prepend the initial attention `att0`.
        attention = tf.reshape(attention, (-1, self.batch_size, 1, self.cell.att_size), 'shape_attention')
        self.attention = tf.concat(axis=0, values=(self.cell.att0[tf.newaxis], attention[:-1]))
        self.att_pred_bbox = self.cell.attention.attention_to_bbox(self.attention)
        self.att_pred_bbox_wo_bias = self.cell.attention.attention_to_bbox(self.attention - self.cell.att_bias)
        self.att_region = self.cell.attention.attention_region(self.attention)

        # Prepend a zero delta at t=0 so the delta sequence realigns with the
        # full (undropped) time axis.
        pred_bbox_delta = tf.reshape(hidden_to_bbox.output, (-1, self.batch_size, 1, 4), 'shape_pred_deltas')
        p = tf.zeros_like(pred_bbox_delta[0])[tf.newaxis]
        p = tf.concat(axis=0, values=(p, pred_bbox_delta))

        # Scale normalized deltas to pixel units (tile of height,width for the
        # 4 bbox coords — assumes inpt_size[:2] is (H, W); TODO confirm).
        self.corr_pred_bbox = p * np.tile(self.inpt_size[:2], (2,)).reshape(1, 4)
        # Final prediction: attention-derived bbox plus the learned correction.
        self.pred_bbox = self.att_pred_bbox_wo_bias + self.corr_pred_bbox
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号