data_generation.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:information-bottleneck 作者: djstrouse 项目源码 文件源码
def gen_blob(plot=True):

    # set name
    name = "blob"

    # set generative parameters  
    mu1 = np.array([0,0])
    sig1 = np.eye(2)
    n1 = 90
    param = {'mu1': mu1, 'sig1': sig1, 'n1': n1}

    # make labels
    labels = np.array([0]*n1)

    # make coordinates
    coord = np.random.multivariate_normal(mu1,sig1,n1)

    # make dataset
    ds = dataset(coord = coord, labels = labels, gen_param = param, name = name)

    # plot coordinates
    if plot: ds.plot_coord()

    # normalize
    ds.normalize_coord()
    if plot: ds.plot_coord()

    return ds
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号