import numpy as np
from torch.autograd import Variable


def train(epoch, prior):
    model.train()
    train_loss = 0
    # prior = BayesianGaussianMixture(n_components=1, covariance_type='diag')
    # First pass: collect the latent codes z for the whole training set,
    # then refit the mixture prior on them.
    tmp = []
    for (data, _) in train_loader:
        data = Variable(data)
        if args.cuda:
            data = data.cuda()
        recon_batch, mu, logvar, z = model(data)
        tmp.append(z.cpu().data.numpy())
    print('Update Prior')
    prior.fit(np.vstack(tmp))
    print('prior: ' + str(prior.weights_))
    # Second pass: optimize the VAE against the freshly fitted prior.
    for batch_idx, (data, _) in enumerate(train_loader):
        data = Variable(data)
        if args.cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar, z = model(data)
        loss = loss_function(recon_batch, data, mu, logvar, prior, z)
        loss.backward()
        train_loss += loss.data[0]
        optimizer.step()
        # if batch_idx % args.log_interval == 0:
        #     print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        #         epoch, batch_idx * len(data), len(train_loader.dataset),
        #         100. * batch_idx / len(train_loader),
        #         loss.data[0] / len(data)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
    return prior
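
Since train() both consumes and returns the prior, the epoch loop has to carry the fitted mixture from one epoch to the next. Below is a minimal driver sketch under stated assumptions: the prior is a sklearn BayesianGaussianMixture constructed as in the commented-out line above, and args.epochs holds the number of epochs (args.epochs is an assumed attribute, not shown in the snippet).

from sklearn.mixture import BayesianGaussianMixture

# Assumed setup: construct the prior once, as in the commented-out line in train().
prior = BayesianGaussianMixture(n_components=1, covariance_type='diag')
for epoch in range(1, args.epochs + 1):  # args.epochs is an assumption
    # train() refits the prior on the current latent codes and returns it,
    # so the updated prior is threaded into the next epoch.
    prior = train(epoch, prior)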