def gen_2cigars(plot=True, rng=None):
    """Generate the "2cigars" toy dataset: two elongated 2-D Gaussian clusters.

    Both clusters share the covariance diag(25, 1), i.e. standard deviation 5
    along the first axis and 1 along the second, so each point cloud is long
    and thin ("cigar"-shaped). The clusters are centred at (0, -4) and (0, 4)
    and contain 50 points each.

    Parameters
    ----------
    plot : bool, optional
        If true, call ``ds.plot_coord()`` once on the raw coordinates and
        once more after ``ds.normalize_coord()`` has been called.
    rng : object with a ``multivariate_normal(mean, cov, size)`` method, optional
        Random source used for sampling, e.g. ``np.random.default_rng(seed)``
        or ``np.random.RandomState(seed)``, to make the draw reproducible.
        When None (the default), the global ``np.random`` state is used,
        matching the previous behaviour.

    Returns
    -------
    dataset
        Built from the sampled coordinates, the labels (0 for the first
        cluster, 1 for the second), the generative-parameter dict and the
        name "2cigars", with ``normalize_coord()`` already called on it.
    """
    name = "2cigars"
    # Fall back to the module-level legacy RNG so unseeded calls keep
    # drawing from (and advancing) the global state exactly as before.
    sampler = np.random if rng is None else rng

    # Generative parameters: identical covariances, means offset along axis 1.
    mu1 = np.array([0, -4])
    sig1 = np.array([[25, 0], [0, 1]])
    n1 = 50
    mu2 = np.array([0, 4])
    sig2 = np.array([[25, 0], [0, 1]])
    n2 = 50
    param = {'mu1': mu1, 'sig1': sig1, 'n1': n1,
             'mu2': mu2, 'sig2': sig2, 'n2': n2}

    # Labels follow the stacking order of the coordinates: n1 zeros, then n2 ones.
    labels = np.array([0] * n1 + [1] * n2)
    coord = np.concatenate((sampler.multivariate_normal(mu1, sig1, n1),
                            sampler.multivariate_normal(mu2, sig2, n2)))

    ds = dataset(coord=coord, labels=labels, gen_param=param, name=name)
    if plot:
        ds.plot_coord()
    # The return value is ignored, so this presumably normalizes ds in place
    # (NOTE(review): confirm against the dataset class).
    ds.normalize_coord()
    if plot:
        ds.plot_coord()
    return ds
data_generation.py 文件源码
python
阅读 38
收藏 0
点赞 0
评论 0
评论列表
文章目录