factorization_ops_test.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def _run_test_als(self, use_factors_weights_cache):
    """Check that a WALSModel with weights=None matches one with weights=0.

    Both models are built from the same random column initialisation. After
    a single row update on ``self._wals_inputs`` their row factors must agree
    (atol=1e-3); after a single column update fed with a shuffled column
    slice of ``INPUT_MATRIX`` their column factors must agree
    (rtol=5e-3, atol=1e-2).

    Args:
      use_factors_weights_cache: passed through unchanged to both WALSModel
        constructors.
    """

    def build_and_update_rows(weight, initial_cols):
      # Build a model whose row and column weights are both `weight`, run
      # its setup ops one after another, then apply one row update.
      model = tf.contrib.factorization.WALSModel(
          5, 7, 3,
          col_init=initial_cols,
          row_weights=weight,
          col_weights=weight,
          use_factors_weights_cache=use_factors_weights_cache)
      for op_name in ("initialize_op",
                      "worker_init",
                      "row_update_prep_gramian_op",
                      "initialize_row_update_op"):
        getattr(model, op_name).run()
      row_update = model.update_row_factors(self._wals_inputs)[1]
      row_update.run()
      return model, [factor.eval() for factor in model.row_factors]

    def update_cols(model, sparse_input, feeds):
      # Run the column-update preparation ops, then one fed column update.
      for op_name in ("col_update_prep_gramian_op",
                      "initialize_col_update_op"):
        getattr(model, op_name).run()
      col_update = model.update_col_factors(sp_input=sparse_input)[1]
      col_update.run(feed_dict=feeds)
      return [factor.eval() for factor in model.col_factors]

    with self.test_session():
      shared_col_init = np.random.rand(7, 3)

      # Row updates: weights=None model first, then the weights=0 model.
      als_model, als_rows = build_and_update_rows(None, shared_col_init)
      wals_model, wals_rows = build_and_update_rows(0, shared_col_init)

      for als_row, wals_row in zip(als_rows, wals_rows):
        self.assertAllClose(als_row, wals_row, atol=1e-3)

      # Partial column updates: only column slices 2 and 0, shuffled.
      sparse_cols = np_matrix_to_tf_sparse(
          INPUT_MATRIX, col_slices=[2, 0], shuffle=True).eval()
      sp_feeder = tf.sparse_placeholder(tf.float32)
      feeds = {sp_feeder: sparse_cols}

      als_cols = update_cols(als_model, sp_feeder, feeds)
      wals_cols = update_cols(wals_model, sp_feeder, feeds)

      for als_col, wals_col in zip(als_cols, wals_cols):
        self.assertAllClose(als_col, wals_col, rtol=5e-3, atol=1e-2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号