visualization_ex.py 文件源码

python
阅读 61 收藏 0 点赞 0 评论 0

项目:Lyssandra 作者: ektormak 项目源码 文件源码
def dictionary_learn_ex():
    """Learn a patch dictionary from LFW face images and plot its atoms.

    Steps performed:

    1. Load grayscale LFW face images with ``fetch_lfw_people``.
    2. Extract up to 1e5 patches of side ``patch_shape[0]`` with
       ``extract_patches``.
    3. Fit an ``online_dictionary_coder`` with ``n_atoms`` atoms, using a
       ``sparse_encoder`` configured with the ``'bomp'`` algorithm.
    4. Show every learned atom (one column of ``D``) in a square grid of
       subplots.

    The function takes no arguments and returns nothing; its only outputs
    are progress messages on stdout and a matplotlib window.
    """
    patch_shape = (18, 18)
    n_atoms = 225
    n_nonzero_coefs = 2
    n_jobs = 8

    lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4,
                                  color=False)

    # Build rescaled copies instead of dividing in place, so the arrays held
    # by the dataset object are left untouched.
    # NOTE(review): the 1/255 factor presumably assumes pixel values in
    # 0-255 -- confirm against the installed scikit-learn version.
    imgs = [img / 255. for img in lfw_people.images]

    # Single-argument print calls produce the same output on Python 2 and 3.
    print('Extracting reference patches...')
    X = extract_patches(imgs, patch_size=patch_shape[0], scale=False,
                        n_patches=int(1e5), verbose=True, n_jobs=n_jobs)
    print('number of patches: %d' % X.shape[1])

    se = sparse_encoder(algorithm='bomp',
                        params={'n_nonzero_coefs': n_nonzero_coefs},
                        n_jobs=n_jobs)

    odc = online_dictionary_coder(n_atoms=n_atoms, sparse_coder=se,
                                  n_epochs=1, batch_size=1000, non_neg=False,
                                  verbose=True, n_jobs=n_jobs)
    odc.fit(X)
    D = odc.D

    # Plot every atom; derive the grid from n_atoms so the two cannot drift
    # apart. The side is the smallest integer whose square holds all atoms
    # (15 for the 225 atoms used here).
    n_atoms_plot = n_atoms
    grid_side = int(n_atoms_plot ** 0.5)
    if grid_side * grid_side < n_atoms_plot:
        grid_side += 1

    plt.figure(figsize=(4.2, 4))
    for i in range(n_atoms_plot):
        plt.subplot(grid_side, grid_side, i + 1)
        plt.imshow(D[:, i].reshape(patch_shape), cmap=plt.cm.gray)
        plt.xticks(())
        plt.yticks(())
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号