def train(epoch, prior):
    """Fit a fresh Bayesian Gaussian mixture prior to every sample in ``train_loader``.

    Each ``(data, _)`` pair yielded by the module-level ``train_loader`` is
    converted to a NumPy array and flattened to two dimensions, keeping the
    leading axis. All batches are stacked row-wise, and a newly constructed
    ``BayesianGaussianMixture`` is fitted to the result. The model uses 50
    components, diagonal covariances, 5 initialisations and up to 1000
    iterations.

    Args:
        epoch: Unused; kept so existing call sites keep working.
        prior: Ignored. A brand-new model is always constructed and returned,
            so any previously fitted prior passed in is discarded.

    Returns:
        The fitted ``BayesianGaussianMixture``.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    # NOTE(review): the incoming ``prior`` is replaced, exactly as in the
    # original code; confirm callers do not expect it to be refined in place.
    prior = BayesianGaussianMixture(
        n_components=50, covariance_type='diag', n_init=5, max_iter=1000
    )

    rows = []
    for data, _ in train_loader:
        # Convert once per batch; presumably ``data`` is a CPU torch tensor
        # (it exposes ``.numpy()``) -- TODO confirm.
        batch = data.numpy()
        # Flatten all trailing dimensions so each sample becomes one row.
        rows.append(batch.reshape(batch.shape[0], -1))

    # np.vstack([]) fails with an unhelpful message; report the real cause.
    if not rows:
        raise ValueError("train_loader yielded no batches; cannot fit the prior")

    prior.fit(np.vstack(rows))
    return prior
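# 评论列表 ("Comment list") -- leftover label from the web page this code was copied from.
# 文章目录 ("Table of contents") -- leftover label from the web page this code was copied from.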