def optimize(model, sampler, train, valid, max_epochs=None, patience=None,
             n_valid_users=300):
    """
    Optimize the model, alternating validation and training epochs.

    Each epoch first prints the mean recall on a sample of validation users,
    then runs ``EVALUATION_EVERY_N_BATCHES`` training mini-batches and prints
    their mean loss. With the default arguments the loop never terminates
    (the original behavior); pass ``max_epochs`` and/or ``patience`` to make
    it stop.

    :param model: model to optimize
    :param sampler: mini-batch sampler
    :param train: train user-item matrix
    :param valid: validation user-item matrix
    :param max_epochs: stop after this many validate/train epochs;
        ``None`` (default) means no limit
    :param patience: early stopping -- stop once the sampled validation recall
        has not improved on its best value for this many consecutive
        evaluations; ``None`` (default) disables early stopping
    :param n_valid_users: number of validation users sampled for recall
    :return: the ``tf.Session`` holding the trained variables (only returned
        when ``max_epochs`` or ``patience`` ends the loop)
    """
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    if model.feature_projection is not None:
        # initialize item embedding with feature projection
        sess.run(tf.assign(model.item_embeddings, model.feature_projection))

    # train/valid are never modified in this function, so the evaluator and
    # the sampled validation users are built once rather than every epoch.
    # NOTE(review): assumes RecallEvaluator.eval has no side effects -- confirm.
    validation_recall = RecallEvaluator(train, valid)
    valid_users = list(set(valid.nonzero()[0]))[:n_valid_users]

    best_recall = None
    epochs_without_improvement = 0
    epoch = 0
    while max_epochs is None or epoch < max_epochs:
        recall = _sampled_validation_recall(
            sess, model, validation_recall, valid_users)
        print("\nRecall on (sampled) validation set: {}".format(recall))

        # Early stopping only makes sense when there are users to score;
        # otherwise recall is NaN and never compares as an improvement.
        if patience is not None and valid_users:
            if best_recall is None or recall > best_recall:
                best_recall = recall
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    break

        loss = _train_epoch(sess, model, sampler)
        print("\nTraining loss {}".format(loss))
        epoch += 1
    return sess


def _sampled_validation_recall(sess, model, evaluator, users, chunk_size=300):
    """Return the mean recall of ``evaluator`` over ``users`` (NaN if empty)."""
    recalls = []
    # score users with at most ``chunk_size`` ids per sess.run call
    for user_chunk in toolz.partition_all(chunk_size, users):
        scores = sess.run(model.item_scores, {model.score_user_ids: user_chunk})
        recalls.extend(evaluator.eval(user, user_scores)
                       for user, user_scores in zip(user_chunk, scores))
    # numpy.mean([]) emits a RuntimeWarning before returning NaN; skip that
    return numpy.mean(recalls) if recalls else float("nan")


def _train_epoch(sess, model, sampler):
    """Run EVALUATION_EVERY_N_BATCHES optimizer steps; return the mean loss."""
    losses = []
    for _ in tqdm(range(EVALUATION_EVERY_N_BATCHES), desc="Optimizing..."):
        user_pos, neg = sampler.next_batch()
        _, loss = sess.run((model.optimize, model.loss),
                           {model.user_positive_items_pairs: user_pos,
                            model.negative_samples: neg})
        losses.append(loss)
    return numpy.mean(losses)
评论列表
文章目录