def __init__(self, inpt, bbox0, presence0, batch_size, glimpse_size,
feature_extractor, rnn_units, bbox_gain=-4., att_gain=-2.5,
zoneout_prob=0., identity_init=True, attention_module=RATMAttention,
normalize_glimpse=False, debug=False, clip_bbox=False,
transform_init_features=False, transform_init_state=False,
dfn_readout=False, feature_shape=None, is_training=True):
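"""Stores the tracker configuration and registers it with the base module.

`inpt` holds the input frames (the leading two dimensions are assumed to
be time and batch, see `inpt_size` below); `bbox0` and `presence0` give
the target's bounding box and presence in the first frame; the remaining
arguments configure the glimpse, feature extractor, recurrent core,
attention module and optional readouts.
"""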
self.inpt = inpt
self.bbox0 = bbox0
self.presence0 = presence0
self.glimpse_size = glimpse_size
self.feature_extractor = feature_extractor
self.rnn_units = rnn_units
self.batch_size = batch_size
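# Shape of a single frame; inpt is assumed to be laid out as (time, batch, H, W, C).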
self.inpt_size = convert_shape(inpt.get_shape()[2:], np.int32)
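# Broadcast the scalar gains to one entry per bbox / attention parameter,
# with a leading axis so they broadcast over the batch.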
self.bbox_gain = ensure_array(bbox_gain, 4)[np.newaxis]
self.att_gain = ensure_array(att_gain, attention_module.n_params)[np.newaxis]
self.zoneout_prob = zoneout_prob
self.identity_init = identity_init
self.attention_module = attention_module
self.normalize_glimpse = normalize_glimpse
self.debug = debug
self.clip_bbox = clip_bbox
self.transform_init_features = transform_init_features
self.transform_init_state = transform_init_state
self.dfn_readout = dfn_readout
self.feature_shape = feature_shape
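# Ensure is_training is a tensor, so it can be supplied either as a constant or as a placeholder.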
self.is_training = tf.convert_to_tensor(is_training)
super(HierarchicalAttentiveRecurrentTracker, self).__init__(self.__class__.__name__)
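# Register is_training with the base module; a ValueError presumably means
# it is already registered, in which case it is safe to ignore.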
try:
self.register(is_training)
except ValueError: pass
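# A minimal construction sketch (the glimpse_size and rnn_units values below
# are illustrative assumptions, not necessarily the repo's defaults):
#
#   tracker = HierarchicalAttentiveRecurrentTracker(
#       inpt, bbox0, presence0, batch_size=batch_size,
#       glimpse_size=(40, 40), feature_extractor=feature_extractor,
#       rnn_units=100)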