utils.py source code


Project: sdp  Author: tansey
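```python
import itertools

import numpy as np
from scipy.sparse import coo_matrix


def t_offset(t, dim, offset):
    # Assumed behavior of the t_offset helper called below (its definition is not
    # part of this snippet): return tuple t with coordinate `dim` shifted by offset.
    return t[:dim] + (t[dim] + offset,) + t[dim + 1:]


def get_sparse_penalty_matrix(num_classes):
    '''Creates a sparse graph-fused lasso penalty matrix (zeroth-order trend filtering)
    under the assumption that the class bins are arranged along an evenly spaced
    p-dimensional grid.'''
    bins = [np.arange(c) for c in num_classes]
    # Map each grid coordinate tuple to its flat (row-major) column index.
    idx_map = {t: idx for idx, t in enumerate(itertools.product(*bins))}
    indices = []
    values = []
    rows = 0
    for idx1, t1 in enumerate(itertools.product(*bins)):
        for dim in range(len(t1)):  # range replaces Python 2's xrange
            # Add one row per edge to the next cell along this axis:
            # +1 at the current cell, -1 at its neighbor.
            if t1[dim] < (num_classes[dim] - 1):
                t2 = t_offset(t1, dim, 1)
                idx2 = idx_map[t2]
                indices.append([rows, idx1])
                values.append(1)
                indices.append([rows, idx2])
                values.append(-1)
                rows += 1
    # TensorFlow version (disabled):
    # D_shape = [rows, np.prod(num_classes)]
    # return tf.sparse_reorder(tf.SparseTensor(indices=indices, values=values, shape=D_shape))
    # Use scipy's sparse libraries until TensorFlow's sparse matrix multiplication is fully implemented.
    D_shape = (rows, int(np.prod(num_classes)))
    row_indices = [x for x, y in indices]
    col_indices = [y for x, y in indices]
    return coo_matrix((values, (row_indices, col_indices)), shape=D_shape)
```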
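Each row of the matrix returned by `get_sparse_penalty_matrix` pairs one grid cell with its next neighbor along a single axis (+1 and -1). As a result, `||D beta||_1` sums the absolute jumps between adjacent bins, which is the graph-fused lasso penalty. The sketch below is not from the sdp project. It builds the same set of rows with a different technique, Kronecker products of 1-D first-difference matrices via `scipy.sparse`, and then evaluates the penalty. `first_difference` and `grid_penalty_matrix` are hypothetical names introduced for illustration. Rows come out grouped by axis rather than interleaved per cell as in the loop version; this changes the row order but not the penalty value.

```python
import numpy as np
import scipy.sparse as sp


def first_difference(n):
    '''(n-1) x n matrix: +1 on the diagonal, -1 on the superdiagonal.'''
    return sp.diags([np.ones(n - 1), -np.ones(n - 1)], [0, 1], shape=(n - 1, n))


def grid_penalty_matrix(shape):
    '''Hypothetical vectorized counterpart of get_sparse_penalty_matrix.

    Block k is I (x) ... (x) D_{n_k} (x) ... (x) I for a row-major grid, so each
    row is +1 at a cell and -1 at its neighbor one step further along axis k.
    '''
    blocks = []
    for k, n in enumerate(shape):
        if n < 2:
            continue  # an axis of length 1 has no neighbor pairs
        factors = [sp.identity(m) for m in shape]
        factors[k] = first_difference(n)
        block = factors[0]
        for f in factors[1:]:
            block = sp.kron(block, f)
        blocks.append(block)
    if not blocks:
        return sp.coo_matrix((0, int(np.prod(shape))))
    return sp.vstack(blocks).tocoo()


if __name__ == "__main__":
    D = grid_penalty_matrix((2, 3))
    beta = np.array([0.0, 0.0, 1.0, 0.0, 0.0, 1.0])  # 2 x 3 grid, flattened row-major
    print(D.shape)                 # (7, 6): 3 vertical + 4 horizontal neighbor pairs
    print(np.abs(D @ beta).sum())  # 2.0: one jump of size 1 in each row of the grid
```

The Kronecker form avoids a Python loop over every cell, while the loop in `get_sparse_penalty_matrix` keeps each cell's edges together. For a 1-D grid of 3 bins, both constructions produce the rows `[1, -1, 0]` and `[0, 1, -1]`.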