def gen_easytest(plot=True, *, n=10, seed=None):
    """Generate the "easytest" toy dataset: three 2-D Gaussian clusters.

    The cluster means (0, 0), (sqrt(75), 5) and (0, 10) are the vertices of an
    equilateral triangle with side length 10. Every cluster uses the 2x2
    identity covariance, so the clusters are far apart relative to their
    spread. Points from cluster i (0-based) receive label i.

    Args:
        plot: If true, call ``ds.plot_coord()`` once before and once after
            ``ds.normalize_coord()``.
        n: Number of points drawn per cluster. Defaults to 10, the value that
            was previously hard-coded. Must be at least 1.
        seed: Optional integer seed. When ``None`` (the default), samples are
            drawn from NumPy's global random state, as before. Otherwise a
            private ``np.random.RandomState(seed)`` is used, which makes the
            dataset reproducible without touching the global state.

    Returns:
        The ``dataset`` built from the sampled coordinates and labels, with
        ``gen_param`` holding the keys ``mu1, sig1, n1, mu2, ..., n3``. It is
        returned after ``normalize_coord()`` has been called on it.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    name = "easytest"

    # (mean, covariance, size) per cluster; the list index is the class label.
    clusters = [
        (np.array([0, 0]), np.eye(2), n),
        (np.array([math.sqrt(75), 5]), np.eye(2), n),
        (np.array([0, 10]), np.eye(2), n),
    ]

    # Record the generative parameters under the historical keys
    # mu1/sig1/n1, mu2/sig2/n2, mu3/sig3/n3 (same insertion order as before).
    param = {}
    for idx, (mu, sig, size) in enumerate(clusters, start=1):
        param[f"mu{idx}"] = mu
        param[f"sig{idx}"] = sig
        param[f"n{idx}"] = size

    # seed=None keeps the original behaviour of using the global NumPy RNG;
    # an explicit seed gives an isolated, reproducible stream.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Label i repeated size_i times, in cluster order.
    labels = np.repeat(np.arange(len(clusters)), [size for _, _, size in clusters])

    # Draw clusters in order 1, 2, 3 so the RNG consumption matches the labels.
    coord = np.concatenate(
        [rng.multivariate_normal(mu, sig, size) for mu, sig, size in clusters]
    )

    ds = dataset(coord=coord, labels=labels, gen_param=param, name=name)

    # Optionally show the raw coordinates, normalize, then show them again.
    if plot:
        ds.plot_coord()
    ds.normalize_coord()
    if plot:
        ds.plot_coord()
    return ds
data_generation.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录