def _run_test_als(self, use_factors_weights_cache):
    """Check that an unweighted and a zero-weighted WALSModel agree.

    Two 5 x 7 WALSModel instances with 3 factors share one random column
    initialisation. One is built with ``row_weights=None`` /
    ``col_weights=None`` ("ALS"). The other is built with ``row_weights=0`` /
    ``col_weights=0`` ("WALS"). Each model runs:

      1. one row-factor update on ``self._wals_inputs``;
      2. a partial column-factor update fed through a sparse placeholder.

    After each step the evaluated factor shards of the two models are
    asserted to be pairwise close. The shard counts are also asserted equal.

    Args:
      use_factors_weights_cache: forwarded unchanged to both WALSModel
        constructors.
    """
    with self.test_session():
        # Shared by both models so that any difference comes from the weights.
        col_init = np.random.rand(7, 3)

        def _build_and_update_rows(weights):
            """Build a model with `weights` as row/col weights; update rows once.

            Returns the model and the evaluated list of its row-factor shards.
            """
            model = tf.contrib.factorization.WALSModel(
                5, 7, 3,
                col_init=col_init,
                row_weights=weights,
                col_weights=weights,
                use_factors_weights_cache=use_factors_weights_cache)
            model.initialize_op.run()
            model.worker_init.run()
            model.row_update_prep_gramian_op.run()
            model.initialize_row_update_op.run()
            # update_row_factors returns a tuple; element [1] is the op run here.
            process_input_op = model.update_row_factors(self._wals_inputs)[1]
            process_input_op.run()
            return model, [x.eval() for x in model.row_factors]

        # Models are built and updated strictly in this order (ALS first), as
        # in the original, so graph construction order is unchanged.
        als_model, row_factors1 = _build_and_update_rows(None)
        wals_model, row_factors2 = _build_and_update_rows(0)

        # zip() would silently drop unmatched shards, so check counts first.
        self.assertEqual(len(row_factors1), len(row_factors2))
        for r1, r2 in zip(row_factors1, row_factors2):
            self.assertAllClose(r1, r2, atol=1e-3)

        # Here we test partial column updates. `col_slices=[2, 0]` and
        # `shuffle=True` are interpreted by np_matrix_to_tf_sparse (defined
        # elsewhere). Presumably they select a subset of columns in shuffled
        # order; confirm against that helper.
        sp_c = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[2, 0],
                                      shuffle=True).eval()
        sp_feeder = tf.sparse_placeholder(tf.float32)
        # One feed dict serves both models; the original rebuilt it identically.
        feed_dict = {sp_feeder: sp_c}

        def _update_cols(model):
            """Run one column-factor update on `model`; return evaluated shards."""
            model.col_update_prep_gramian_op.run()
            model.initialize_col_update_op.run()
            process_input_op = model.update_col_factors(sp_input=sp_feeder)[1]
            process_input_op.run(feed_dict=feed_dict)
            return [x.eval() for x in model.col_factors]

        col_factors1 = _update_cols(als_model)
        col_factors2 = _update_cols(wals_model)

        self.assertEqual(len(col_factors1), len(col_factors2))
        for c1, c2 in zip(col_factors1, col_factors2):
            self.assertAllClose(c1, c2, rtol=5e-3, atol=1e-2)
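# Comment list  (web-page residue from the scraped source; not part of the code)
# Table of contents  (web-page residue from the scraped source; not part of the code)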