factorization_ops_test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def _run_test_als_transposed(self, use_factors_weights_cache):
    """Checks a transposed-input ALS row update against an untransposed one.

    Two WALSModels are built with the same sizes and the same random column
    initialization. One has no row/column weights (`row_weights=None`,
    `col_weights=None`); the other has all-zero row/column weights.

    A partial row update covering the same rows of INPUT_MATRIX is run on
    each model:
      * the first model is fed the transposed sparse slice with
        `transpose_input=True`;
      * the second model is fed the untransposed slice.
    The resulting factors for the updated rows are expected to agree to
    within `atol=1e-3`.

    Args:
      use_factors_weights_cache: forwarded unchanged to both WALSModel
        constructors.
    """
    with self.test_session():
      # Rows of INPUT_MATRIX included in the sparse update. Every other row
      # keeps its random initial value, so only these rows are compared.
      updated_rows = [3, 1]
      col_init = np.random.rand(7, 3)

      def make_initialized_model(row_weights, col_weights):
        # Builds a model sharing `col_init` and runs its initialization ops.
        model = tf.contrib.factorization.WALSModel(
            5, 7, 3,
            col_init=col_init,
            row_weights=row_weights,
            col_weights=col_weights,
            use_factors_weights_cache=use_factors_weights_cache)
        model.initialize_op.run()
        model.worker_init.run()
        return model

      # Construction order (ALS first, then WALS) matches the original test.
      als_model = make_initialized_model(row_weights=None, col_weights=None)
      wals_model = make_initialized_model(row_weights=[0] * 5,
                                          col_weights=[0] * 7)

      sp_feeder = tf.sparse_placeholder(tf.float32)
      # Identical row slices; the ALS model receives them transposed.
      sp_r_t = np_matrix_to_tf_sparse(INPUT_MATRIX, updated_rows,
                                      transpose=True).eval()
      sp_r = np_matrix_to_tf_sparse(INPUT_MATRIX, updated_rows).eval()

      als_model.row_update_prep_gramian_op.run()
      als_model.initialize_row_update_op.run()
      process_input_op = als_model.update_row_factors(sp_input=sp_feeder,
                                                      transpose_input=True)[1]
      process_input_op.run(feed_dict={sp_feeder: sp_r_t})
      # Fetch the shard once and slice out the updated rows, instead of
      # running a separate eval() for every compared row.
      row_factors1 = als_model.row_factors[0].eval()[updated_rows]

      wals_model.row_update_prep_gramian_op.run()
      wals_model.initialize_row_update_op.run()
      process_input_op = wals_model.update_row_factors(sp_input=sp_feeder)[1]
      process_input_op.run(feed_dict={sp_feeder: sp_r})
      row_factors2 = wals_model.row_factors[0].eval()[updated_rows]

      # Element-wise comparison of both (len(updated_rows), 3) slices.
      self.assertAllClose(row_factors1, row_factors2, atol=1e-3)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号