test_larcqy.py source code

python

Project: factorix. Author: gbouchar.
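import numpy as np
import tensorflow as tf
# data_to_batches, feed_dict_sampler and learn are helpers from the factorix project;
# their module paths are not shown on this page.


def test_matrix_factorization(verbose=False):
    np.random.seed(1)
    n, m, rank = 7, 6, 3
    # Ground truth: a random n x m matrix of rank `rank`.
    mat = np.random.randn(n, rank).dot(np.random.randn(rank, m))
    # One observation per cell: index pair [row i, column j shifted by n] and the value mat[i, j].
    tuples = [([i, n + j], mat[i, j]) for i in range(n) for j in range(m)]
    # A single minibatch holding all n * m observations.
    tuple_iterable = data_to_batches(tuples, minibatch_size=n * m)
    sampler, (x, y) = feed_dict_sampler(tuple_iterable, types=[np.int64, np.float32])
    # One shared table: rows 0..n-1 embed matrix rows, rows n..n+m-1 embed matrix columns.
    emb_var = tf.Variable(tf.cast(np.random.randn(n + m, rank), 'float32'))
    offset = tf.Variable(tf.cast(1.0, 'float32'))
    # gather -> (batch, 2, rank); product over axis 1, then sum over rank -> dot product of the pair.
    pred = tf.reduce_sum(tf.reduce_prod(tf.gather(emb_var, x), 1), 1) + offset
    loss_op = tf.reduce_mean(tf.square(pred - y))
    emb, offset_val = learn(loss_op, sampler, max_epochs=200, variables=[emb_var, offset])
    mat_est = emb[:n, :].dot(emb[n:, :].T)
    err = np.linalg.norm(mat_est - mat)
    if verbose:
        print(err ** 2)  # squared error; we should have recovered the low-rank matrix
    else:
        assert err < 1e-3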
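The graph in this test scores each observed cell by gathering row embedding `i` and column embedding `n + j` from one shared table and summing their elementwise product, i.e. a dot product, plus a scalar offset. Below is a minimal NumPy sketch of that model under stated assumptions: `gathered_scores` and `fit_low_rank` are hypothetical names, and plain full-batch gradient descent with a fixed step size stands in for factorix's `learn`, whose optimizer and settings are not shown here.

```python
import numpy as np


def gathered_scores(emb, pairs):
    """NumPy mirror of reduce_sum(reduce_prod(gather(emb, x), 1), 1).

    emb   : (n + m, rank) table, row embeddings first, then column embeddings.
    pairs : (batch, 2) int array of [i, n + j] index pairs.
    Returns a (batch,) array whose k-th entry is emb[i] . emb[n + j].
    """
    g = emb[pairs]                     # (batch, 2, rank), like tf.gather
    return g.prod(axis=1).sum(axis=1)  # elementwise product of the pair, summed over rank


def fit_low_rank(mat, rank, lr=0.1, steps=2000, seed=0):
    """Full-batch gradient descent on mean((U V^T + b - mat)^2).

    Hypothetical stand-in for factorix's `learn`; the step size, small
    initialisation and step count are assumptions, not the project's settings.
    """
    rng = np.random.default_rng(seed)
    n, m = mat.shape
    emb = 0.1 * rng.standard_normal((n + m, rank))  # small random init (assumption)
    b = 0.0
    losses = []
    for _ in range(steps):
        u, v = emb[:n], emb[n:]        # views into the shared table
        r = u @ v.T + b - mat          # residuals, shape (n, m)
        losses.append(float(np.mean(r ** 2)))
        scale = 2.0 / r.size           # derivative of the mean of squares
        grad_u = scale * r @ v
        grad_v = scale * r.T @ u
        grad_b = scale * r.sum()
        # All gradients are computed before any in-place update touches the views.
        emb[:n] -= lr * grad_u
        emb[n:] -= lr * grad_v
        b -= lr * grad_b
    return emb, b, losses


# Usage, with the same sizes as the test above.
rng = np.random.default_rng(1)
n, m, rank = 7, 6, 3
mat = rng.standard_normal((n, rank)) @ rng.standard_normal((rank, m))
pairs = np.array([[i, n + j] for i in range(n) for j in range(m)])

# The gather/prod/sum scores equal the row-by-column dot products.
emb0 = rng.standard_normal((n + m, rank))
scores = gathered_scores(emb0, pairs).reshape(n, m)
print(np.allclose(scores, emb0[:n] @ emb0[n:].T))  # prints True

emb, b, losses = fit_low_rank(mat, rank)
mat_est = emb[:n] @ emb[n:].T + b  # include the offset in the reconstruction
print(losses[0], losses[-1], np.linalg.norm(mat_est - mat))
```

Adding `b` back into `mat_est` reconstructs what the model actually predicts. The original test compares `emb[:n, :].dot(emb[n:, :].T)` without the offset, so its tolerance only holds if the learned offset ends up close to zero.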