model.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目：few_shot_mAP_public 作者：eleniTriantafillou 项目源码 文件源码
def __init__(self, config, reuse=False):
    """Build the graph: input placeholders, normalized features, loss, train op.

    Args:
      config: Model configuration. Its ``height``, ``width`` and ``channels``
        attributes fix the per-example shape of the ``x`` image placeholder.
        It is also stored as ``self._config`` (presumably read by the other
        graph-building methods -- confirm).
      reuse: Forwarded unchanged to ``self.forward_pass``; presumably controls
        variable reuse of the embedding network -- TODO confirm.
    """
    self._config = config
    self._x = tf.placeholder(
        tf.float32, [None, config.height, config.width, config.channels],
        name="x")

    embedding = self.forward_pass(reuse)
    # Unit-normalize each embedding along axis 1. l2_normalize clamps the
    # squared norm at a small epsilon (1e-12 by default). An all-zero
    # embedding therefore maps to zeros instead of the NaN features and NaN
    # gradients that x / sqrt(sum(x ** 2)) would produce. The axis is passed
    # positionally because its keyword is `dim` in older TF releases and
    # `axis` in newer ones.
    self._feats = tf.nn.l2_normalize(embedding, 1)

    # Number of positive / negative points for each query.
    self._num_pos = tf.placeholder(tf.int32, [None], name="num_pos")
    self._num_neg = tf.placeholder(tf.int32, [None], name="num_neg")
    # Batch size taken from the runtime (dynamic) shape of the input.
    self._batch_size = tf.shape(self._x)[0]

    # Indices belonging to the positive and negative sets of each query.
    self._pos_inds = tf.placeholder(tf.int32, [None, None], name="pos_inds")
    self._neg_inds = tf.placeholder(tf.int32, [None, None], name="neg_inds")

    # Scalar int32 fed at run time. Per its name, it is the number of queries
    # to parse; it is consumed outside this method -- confirm.
    self._n_queries_to_parse = tf.placeholder(
        tf.int32, [], name="n_queries_to_parse")

    # Solution of loss-augmented inference for each query,
    # shaped (num queries, num_pos, num_neg).
    self._Y_aug = tf.placeholder(
        tf.float32, [None, None, None], name="Y_aug")

    # All placeholders above are created before this call, which presumably
    # reads them from self -- keep this ordering.
    (self._phi_pos, self._phi_neg, self._mAP_score_std,
     self._mAP_score_aug, self._mAP_score_GT,
     self._skipped_queries) = self.perform_inference_mAP()
    self._loss = self.compute_loss()
    self._train_step = self.get_train_step()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号