def optimize(model, sampler, train, valid):
"""
Optimize the model. TODO: implement early-stopping
:param model: model to optimize
:param sampler: mini-batch sampler
:param train: train user-item matrix
:param valid: validation user-item matrix
:return: None
"""
sess = tf.Session()
sess.run(tf.global_variables_initializer())
if model.feature_projection is not None:
# initialize item embedding with feature projection
sess.run(tf.assign(model.item_embeddings, model.feature_projection))
# sample some users to calculate recall validation
valid_users = numpy.random.choice(list(set(valid.nonzero()[0])), size=1000, replace=False)
while True:
# create evaluator on validation set
validation_recall = RecallEvaluator(model, train, valid)
# compute recall on validate set
valid_recalls = []
# compute recall in chunks to utilize speedup provided by Tensorflow
for user_chunk in toolz.partition_all(100, valid_users):
valid_recalls.extend([validation_recall.eval(sess, user_chunk)])
print("\nRecall on (sampled) validation set: {}".format(numpy.mean(valid_recalls)))
# TODO: early stopping based on validation recall
# train model
losses = []
# run n mini-batches
for _ in tqdm(range(EVALUATION_EVERY_N_BATCHES), desc="Optimizing..."):
user_pos, neg = sampler.next_batch()
_, loss = sess.run((model.optimize, model.loss),
{model.user_positive_items_pairs: user_pos,
model.negative_samples: neg})
losses.append(loss)
print("\nTraining loss {}".format(numpy.mean(losses)))