CML.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:Collaborative-metric-learning 作者: KiM55 项目源码 文件源码
def optimize(model, sampler, train, valid):
    """
    Train the model indefinitely, reporting validation recall between rounds.

    Each round first prints the mean recall over a fixed sample of validation
    users, then runs EVALUATION_EVERY_N_BATCHES mini-batch updates and prints
    the mean training loss. TODO: implement early-stopping

    :param model: model to optimize
    :param sampler: mini-batch sampler
    :param train: train user-item matrix
    :param valid: validation user-item matrix
    :return: never returns normally; the loop only ends when an exception
        (e.g. KeyboardInterrupt) propagates, at which point the TensorFlow
        session is closed.
    """
    # number of validation users scored per round (also the scoring chunk size)
    n_valid_users = 300

    # Both of these depend only on train/valid, which this function never
    # modifies, so they are built once instead of on every round.
    # NOTE(review): assumes RecallEvaluator keeps no per-round state -- confirm.
    validation_recall = RecallEvaluator(train, valid)
    # sample some users to calculate recall validation
    valid_users = list(set(valid.nonzero()[0]))[:n_valid_users]

    sess = tf.Session()
    try:
        sess.run(tf.global_variables_initializer())
        if model.feature_projection is not None:
            # initialize item embedding with feature projection
            sess.run(tf.assign(model.item_embeddings, model.feature_projection))

        while True:
            # compute recall on validate set
            valid_recalls = []
            for user_chunk in toolz.partition_all(n_valid_users, valid_users):
                scores = sess.run(model.item_scores,
                                  {model.score_user_ids: user_chunk})
                valid_recalls.extend(
                    validation_recall.eval(user, user_scores)
                    for user, user_scores in zip(user_chunk, scores)
                )
            print("\nRecall on (sampled) validation set: {}".format(numpy.mean(valid_recalls)))
            # TODO: early stopping based on validation recall

            # train model
            losses = []
            # run n mini-batches
            for _ in tqdm(range(EVALUATION_EVERY_N_BATCHES), desc="Optimizing..."):
                user_pos, neg = sampler.next_batch()
                _, loss = sess.run((model.optimize, model.loss),
                                   {model.user_positive_items_pairs: user_pos,
                                    model.negative_samples: neg})
                losses.append(loss)
            print("\nTraining loss {}".format(numpy.mean(losses)))
    finally:
        # release the session's resources when training is interrupted
        sess.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号