vae_dp.py source file

python
Project: VAE_NBP  Author: bobchennan
import numpy as np
import torch

# Relies on module-level objects defined elsewhere in vae_dp.py:
# model, train_loader, optimizer, args, loss_function.
def train(epoch, prior):
    model.train()
    train_loss = 0
    # prior is expected to be a scikit-learn style mixture model, e.g.:
    # prior = BayesianGaussianMixture(n_components=1, covariance_type='diag')

    # Pass 1: encode the whole training set and collect the latent samples z.
    tmp = []
    with torch.no_grad():  # no gradients are needed to refit the prior
        for data, _ in train_loader:
            if args.cuda:
                data = data.cuda()
            recon_batch, mu, logvar, z = model(data)
            tmp.append(z.cpu().numpy())
    print('Update Prior')
    prior.fit(np.vstack(tmp))
    print('prior: ' + str(prior.weights_))

    # Pass 2: one epoch of gradient updates against the refitted prior.
    for batch_idx, (data, _) in enumerate(train_loader):
        if args.cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar, z = model(data)
        loss = loss_function(recon_batch, data, mu, logvar, prior, z)
        loss.backward()
        train_loss += loss.item()  # scalar value of the batch loss
        optimizer.step()
        #if batch_idx % args.log_interval == 0:
        #    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        #        epoch, batch_idx * len(data), len(train_loader.dataset),
        #        100. * batch_idx / len(train_loader),
        #        loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))
    return prior
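
The prior-update step in train() can be tried in isolation. The following is a minimal sketch, assuming the prior is scikit-learn's BayesianGaussianMixture (as the commented-out line in train() suggests). The synthetic `latents` array is a hypothetical stand-in for the stacked encoder outputs, and `n_components=5`, `max_iter=500` and `random_state=0` are illustrative choices, not values taken from the project.

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

# Hypothetical stand-in for np.vstack(tmp) in train(): 400 latent codes of
# dimension 4, drawn from two well-separated clusters.
rng = np.random.default_rng(0)
latents = np.vstack([
    rng.normal(loc=-2.0, scale=0.5, size=(200, 4)),
    rng.normal(loc=2.0, scale=0.5, size=(200, 4)),
])

# n_components is only an upper bound here; the fitted weights indicate
# how many components the model actually relies on.
prior = BayesianGaussianMixture(
    n_components=5,          # illustrative value, not from the project
    covariance_type='diag',  # same covariance type as in the original comment
    max_iter=500,
    random_state=0,
)
prior.fit(latents)

# Same diagnostic that train() prints after each refit.
print('prior: ' + str(prior.weights_))

# Per-sample log-density under the fitted mixture, shape (n_samples,).
log_density = prior.score_samples(latents)
print('mean log-density:', log_density.mean())
```

How loss_function uses `prior` and `z` is not shown in this file excerpt, so the `score_samples` call above only illustrates one quantity a fitted mixture can provide, not necessarily the term the project uses.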