import torch

# log_Bernoulli, log_Normal_standard and log_Normal_diag are log-density
# helpers defined elsewhere in the codebase (a sketch is given below)


def train_vae_VPflow(epoch, args, train_loader, model, optimizer):
    # running sums for the loss and its two components
    train_loss = 0.
    train_re = 0.
    train_kl = 0.
    # set model in training mode
    model.train()
    # KL warm-up: anneal beta linearly from 0 to 1 over the first
    # args.warmup epochs, then keep it fixed at 1
    if args.warmup == 0:
        beta = 1.
    else:
        beta = 1. * (epoch - 1) / args.warmup
    if beta > 1.:
        beta = 1.
    print('beta: {}'.format(beta))
    # start training
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        # dynamic binarization: resample binary inputs each time they are seen
        if args.dynamic_binarization:
            x = torch.bernoulli(data)
        else:
            x = data
        # reset gradients
        optimizer.zero_grad()
        # forward pass: z_0 is the initial posterior sample, z_T the sample
        # after the volume-preserving flow
        x_mean, x_logvar, z_0, z_T, z_q_mean, z_q_logvar = model(x)
        # reconstruction term: Bernoulli log-likelihood of the data
        RE = log_Bernoulli(data, x_mean, average=False)
        # KL term: because the flow is volume-preserving (unit Jacobian),
        # log q(z_T) = log q(z_0), so KL is estimated as log q(z_0) - log p(z_T)
        log_p_z = log_Normal_standard(z_T, dim=1)
        log_q_z = log_Normal_diag(z_0, z_q_mean, z_q_logvar, dim=1)
        KL = beta * (-torch.sum(log_p_z - log_q_z))
        # negative ELBO, averaged over the mini-batch
        loss = (-RE + KL) / data.size(0)
        # backward pass
        loss.backward()
        # optimization step
        optimizer.step()
        train_loss += loss.item()
        train_re += (-RE / data.size(0)).item()
        train_kl += (KL / data.size(0)).item()
    # average over mini-batches (each term is already averaged over the batch)
    train_loss /= len(train_loader)
    train_re /= len(train_loader)
    train_kl /= len(train_loader)
    return model, train_loss, train_re, train_kl
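
The loop above calls three log-density helpers that are not shown. Below is a minimal sketch of what they compute, matching the call signatures used in the function; the clamping constant and the exact handling of average/dim are assumptions, not the original implementation.

import math

import torch


def _reduce(t, average, dim):
    # sum (default) or mean over `dim`; reduce everything when dim is None
    if dim is None:
        return torch.mean(t) if average else torch.sum(t)
    return torch.mean(t, dim) if average else torch.sum(t, dim)


def log_Bernoulli(x, mean, average=False, dim=None):
    # Bernoulli log-likelihood; clamp the mean away from 0/1 for stability
    p = torch.clamp(mean, min=1e-5, max=1. - 1e-5)
    log_bern = x * torch.log(p) + (1. - x) * torch.log(1. - p)
    return _reduce(log_bern, average, dim)


def log_Normal_standard(x, average=False, dim=None):
    # log-density under a standard Normal N(0, I)
    log_norm = -0.5 * (math.log(2. * math.pi) + x * x)
    return _reduce(log_norm, average, dim)


def log_Normal_diag(x, mean, log_var, average=False, dim=None):
    # log-density under a diagonal Normal N(mean, exp(log_var))
    log_norm = -0.5 * (math.log(2. * math.pi) + log_var
                       + (x - mean) ** 2 / torch.exp(log_var))
    return _reduce(log_norm, average, dim)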
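
To see the function end to end, here is a hedged usage sketch on random data. ToyVPFlowVAE is a hypothetical stand-in, not the original model: it only matches the six-output forward signature expected by train_vae_VPflow. Its single Householder reflection has unit Jacobian determinant, so the posterior transform is genuinely volume-preserving, which is what justifies the KL estimate above. All names and hyperparameters below are illustrative.

from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class ToyVPFlowVAE(nn.Module):
    # hypothetical stand-in matching the forward signature used above
    def __init__(self, x_dim=784, z_dim=40):
        super().__init__()
        self.encoder = nn.Linear(x_dim, 2 * z_dim)
        self.v = nn.Parameter(torch.randn(z_dim))  # Householder direction
        self.decoder = nn.Linear(z_dim, x_dim)

    def forward(self, x):
        z_q_mean, z_q_logvar = self.encoder(x).chunk(2, dim=1)
        # reparameterized sample from the diagonal Gaussian posterior
        z_0 = z_q_mean + torch.exp(0.5 * z_q_logvar) * torch.randn_like(z_q_mean)
        # Householder reflection: |det J| = 1, hence volume-preserving
        v = self.v / self.v.norm()
        z_T = z_0 - 2. * (z_0 @ v).unsqueeze(1) * v
        x_mean = torch.sigmoid(self.decoder(z_T))
        return x_mean, None, z_0, z_T, z_q_mean, z_q_logvar


# illustrative flags and random stand-in data (e.g. MNIST-like pixel probabilities)
args = SimpleNamespace(warmup=10, cuda=False, dynamic_binarization=True)
data = torch.rand(256, 784)
train_loader = DataLoader(TensorDataset(data, torch.zeros(256)), batch_size=32)

model = ToyVPFlowVAE()
optimizer = optim.Adam(model.parameters(), lr=5e-4)
for epoch in range(1, 11):
    model, train_loss, train_re, train_kl = train_vae_VPflow(
        epoch, args, train_loader, model, optimizer)
    print('epoch {:2d}  -ELBO {:8.2f}  RE {:8.2f}  KL {:6.2f}'.format(
        epoch, train_loss, train_re, train_kl))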