data_generation.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:information-bottleneck 作者: djstrouse 项目源码 文件源码
def gen_easytest(plot=True):

    # set name
    name = "easytest"

    n = 10
    # set generative parameters  
    mu1 = np.array([0,0])
    sig1 = np.eye(2)
    n1 = n
    mu2 = np.array([math.sqrt(75),5])
    sig2 = np.eye(2)
    n2 = n
    mu3 = np.array([0,10])
    sig3 = np.eye(2)
    n3 = n
    param = {'mu1': mu1, 'sig1': sig1, 'n1': n1,
             'mu2': mu2, 'sig2': sig2, 'n2': n2,
             'mu3': mu3, 'sig3': sig3, 'n3': n3}

    # make labels
    labels = np.array([0]*n1+[1]*n2+[2]*n3)

    # make coordinates
    coord = np.concatenate((np.random.multivariate_normal(mu1,sig1,n1),
                            np.random.multivariate_normal(mu2,sig2,n2),
                            np.random.multivariate_normal(mu3,sig3,n3)))

    # make dataset
    ds = dataset(coord = coord, labels = labels, gen_param = param, name = name)

    # plot coordinates
    if plot: ds.plot_coord()

    # normalize
    ds.normalize_coord()
    if plot: ds.plot_coord()

    return ds
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号