data.py source file

python

Project: GVIN  Author: sufengniu
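import numpy as np
import scipy.sparse
from scipy.sparse import csc_matrix
from tqdm import tqdm

# angle() is called below but is defined elsewhere in data.py (not shown here).


def theta_matrix(coord, adj, preload=True, train=True):
    """Return, for each graph, a sparse matrix of angles from every node to its neighbours.

    As used below, coord is indexed as (graph, node, coordinate) and adj[i] is
    the sparse adjacency matrix of graph i.
    """
    print("creating adjacent theta matrix ...")
    if preload:
        # Load a precomputed result instead of rebuilding it.
        if train:
            return np.load('../data/theta_matrix_train_n_100.npy')
        return np.load('../data/theta_matrix_test_n_100.npy')

    theta_list = []
    for i in tqdm(range(coord.shape[0])):            # one pass per graph
        rows = []
        for j in range(coord.shape[1]):              # one sparse row per node
            col_indice = adj[i][j].nonzero()[1]      # neighbour indices of node j
            # Angle of each neighbour's offset relative to node j.
            theta_row = angle(coord[i, col_indice, :] - coord[i, j, :])
            # Every entry belongs to row 0 of a 1 x num_nodes matrix.
            row_indice = np.zeros(col_indice.shape[0], dtype=np.int32)
            rows.append(csc_matrix((theta_row, (row_indice, col_indice)),
                                   shape=(1, coord.shape[1])))
        # Stack all node rows once rather than growing the matrix row by row.
        theta_list.append(scipy.sparse.vstack(rows))
    return np.array(theta_list)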
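
The inner loop above builds one sparse 1 x N row per node, holding the angle from that node to each of its neighbours. The sketch below reproduces that single step on a toy three-node graph. It rests on several assumptions: `angle` here is a hypothetical stand-in (the helper actually used in data.py is not shown) that treats its input as 2-D offsets and returns `arctan2(dy, dx)`, and `theta_row_for_node` is a name introduced only for this illustration.

```python
import numpy as np
from scipy.sparse import csc_matrix, csr_matrix


def angle(vectors):
    # Hypothetical stand-in for the helper that data.py calls; the original
    # definition is not shown. Assumes rows are 2-D offsets (dx, dy).
    return np.arctan2(vectors[:, 1], vectors[:, 0])


def theta_row_for_node(coord, adj, j):
    """Build the 1 x N sparse row of angles from node j to its neighbours."""
    cols = adj[j].nonzero()[1]                      # neighbour indices of node j
    theta = angle(coord[cols, :] - coord[j, :])     # direction to each neighbour
    rows = np.zeros(cols.shape[0], dtype=np.int32)  # all entries sit in row 0
    return csc_matrix((theta, (rows, cols)), shape=(1, coord.shape[0]))


# Toy graph: node 0 at the origin, node 1 to the east, node 2 to the north.
coord = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
adj = csr_matrix(np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]]))

row0 = theta_row_for_node(coord, adj, 0).toarray()
print(row0)
```

Under this stand-in, the neighbour due east of node 0 has angle 0, so in the dense view it looks the same as a missing edge. Any code that needs to tell the two apart should consult the adjacency matrix rather than the angle values alone.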