training.py 文件源码

python
阅读 38 收藏 0 点赞 0 评论 0

项目:vae_vpflows 作者: jmtomczak 项目源码 文件源码
def train_vae_VPflow(epoch, args, train_loader, model, optimizer):
    """Train a VAE with a volume-preserving flow for one epoch.

    For each batch the optimized objective is ``(-RE + KL) / batch_size``.
    ``RE`` is the Bernoulli log-likelihood of the reconstruction, summed over
    the batch. ``KL`` is ``beta * -sum(log p(z_T) - log q(z_0))``, where
    ``log p`` is a standard normal and ``log q`` a diagonal normal with the
    encoder's mean/log-variance. Using ``log q(z_0)`` for the flowed sample
    ``z_T`` presumably relies on the flow being volume-preserving (no
    log-det term) — confirm against the model.

    Args:
        epoch: current epoch number. With warm-up enabled, beta is
            ``(epoch - 1) / args.warmup`` clamped to [0, 1], so the first
            epoch (epoch == 1) uses beta = 0.
        args: config object; this function reads ``args.warmup``,
            ``args.cuda`` and ``args.dynamic_binarization``.
        train_loader: iterable of ``(data, target)`` batches that supports
            ``len()``. ``target`` is not used by the objective.
        model: module whose call returns the 6-tuple
            ``(x_mean, x_logvar, z_0, z_T, z_q_mean, z_q_logvar)``.
        optimizer: optimizer over the model's parameters.

    Returns:
        ``(model, train_loss, train_re, train_kl)``. Each loss is the mean of
        the per-batch, per-example averages over the epoch. A smaller final
        batch is weighted like a full one. All are 0 if the loader is empty.
    """
    # set model in training mode
    model.train()

    # KL warm-up: beta grows linearly from 0 to 1 over args.warmup epochs.
    # Clamp below at 0 as well, so an epoch counter starting at 0 cannot
    # produce a negative beta (which would reward a large KL).
    if args.warmup == 0:
        beta = 1.
    else:
        beta = min(max(1. * (epoch - 1) / args.warmup, 0.), 1.)
    print('beta: {}'.format(beta))

    train_loss = 0.
    train_re = 0.
    train_kl = 0.

    # labels are not needed for the unsupervised objective, so they are
    # neither moved to the GPU nor wrapped
    for data, _ in train_loader:
        if args.cuda:
            data = data.cuda()
        data = Variable(data)

        # dynamic binarization: draw fresh binary pixels each time a batch is seen
        if args.dynamic_binarization:
            x = torch.bernoulli(data)
        else:
            x = data

        # reset gradients
        optimizer.zero_grad()
        # forward pass; call the module (not .forward) so hooks run
        x_mean, x_logvar, z_0, z_T, z_q_mean, z_q_logvar = model(x)

        # reconstruction term: score against the same x the model was fed
        # (the binarized sample under dynamic binarization), not raw data
        RE = log_Bernoulli(x, x_mean, average=False)
        # KL term
        log_p_z = log_Normal_standard(z_T, dim=1)
        log_q_z = log_Normal_diag(z_0, z_q_mean, z_q_logvar, dim=1)
        KL = beta * (- torch.sum(log_p_z - log_q_z))

        batch_size = data.size(0)
        loss = (-RE + KL) / batch_size
        # backward pass and optimization step
        loss.backward()
        optimizer.step()

        train_loss += _to_scalar(loss)
        train_re += _to_scalar(-RE / batch_size)
        train_kl += _to_scalar(KL / batch_size)

    # per-batch values are already averaged over the batch size; now average
    # over batches (guard against an empty loader instead of dividing by 0)
    n_batches = len(train_loader)
    if n_batches > 0:
        train_loss /= n_batches
        train_re /= n_batches
        train_kl /= n_batches

    return model, train_loss, train_re, train_kl


def _to_scalar(t):
    """Return the Python number held by a one-element tensor/Variable.

    Prefers ``.item()`` (PyTorch >= 0.4), where ``.data[0]`` on a 0-dim tensor
    raises an IndexError. Falls back to ``.data[0]`` for older releases whose
    Variables have no ``.item()``.
    """
    if hasattr(t, 'item'):
        return t.item()
    return t.data[0]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号