data_generation.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:information-bottleneck 作者: djstrouse 项目源码 文件源码
def gen_halfconcentric(plot=True):
    """Generate the "halfconcentric" toy dataset.

    The dataset consists of two labelled groups of 2-D points:

    * label 0: ``ni`` samples drawn from an isotropic Gaussian centred at the
      origin with covariance ``si * I``;
    * label 1: ``nd`` samples drawn around each of ``nt`` centres that are
      evenly spaced in angle along the lower half (theta in [-pi, 0]) of a
      circle of radius ``r``, each with covariance ``so * I``.

    Rows of the coordinate array are ordered inner samples first, then the
    outer samples in order of theta, matching the order of ``labels``.

    Args:
        plot: If true, call ``ds.plot_coord()`` once before and once after
            ``ds.normalize_coord()``.

    Returns:
        The ``dataset`` object built from the coordinates, labels, generative
        parameters and name, after ``normalize_coord()`` has been called on it.
        NOTE(review): ``dataset`` is defined outside this block; the exact
        effect of ``normalize_coord`` / ``plot_coord`` should be confirmed
        there.
    """

    # set name
    name = "halfconcentric"

    # set generative parameters
    nt = 80 # number of thetas
    nd = 1 # number of samples per theta
    no = nd*nt # number of samples for outer circle
    ni = 20 # number of samples for inner circle
    r = 5 # radius of outer loop
    so = .25 # gaussian noise variance of outer circle
    si = .25 # gaussian noise variance of inner circle
    # negative angles -> centres lie on the lower half of the circle
    thetas = -np.linspace(0,math.pi,nt)
    x = [r*math.cos(theta) for theta in thetas]
    y = [r*math.sin(theta) for theta in thetas]
    param = {'nt': nt, 'nd': nd, 'no': no, 'ni': ni, 'r': r, 'so': so, 'si': si}

    # make labels
    labels = np.array([0]*ni+[1]*no)

    # make coordinates
    # Collect every block of samples and concatenate once at the end instead
    # of re-copying the growing array on each iteration. The random draws are
    # made in the same order as before, so seeded results are unchanged.
    outer_cov = so*np.eye(2)
    blocks = [np.random.multivariate_normal(np.array([0,0]),si*np.eye(2),ni)]
    for cx, cy in zip(x, y):
        blocks.append(np.random.multivariate_normal(np.array([cx,cy]),outer_cov,nd))
    coord = np.concatenate(blocks)

    # make dataset
    ds = dataset(coord = coord, labels = labels, gen_param = param, name = name)

    # plot coordinates
    if plot:
        ds.plot_coord()

    # normalize
    ds.normalize_coord()
    if plot:
        ds.plot_coord()

    return ds
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号