def compute_loss_and_gradient(self, x):
    """Return the VAE loss (reconstruction BCE + normalised KL) for batch `x`.

    In TRAIN mode, also backpropagates and takes an optimizer step.
    Returns the loss as a Python float.
    """
    self.optimizer.zero_grad()
    recon_x, z_mean, z_var = self.model_eval(x)
    # assumes x flattens to (batch, 784), i.e. 28x28 images — TODO confirm
    binary_cross_entropy = functional.binary_cross_entropy(recon_x, x.view(-1, 784))
    # Uses analytical KL divergence expression for D_kl(q(z|x) || p(z))
    # Refer to Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # (https://arxiv.org/abs/1312.6114)
    kl_div = -0.5 * torch.sum(1 + z_var.log() - z_mean.pow(2) - z_var)
    # Normalise by the same element count as the reconstruction term
    kl_div /= self.args.batch_size * 784
    loss = binary_cross_entropy + kl_div
    if self.mode == TRAIN:
        loss.backward()
        self.optimizer.step()
    # .item() replaces the pre-0.4 `loss.data[0]`, which raises an
    # IndexError on the 0-dim loss tensors of PyTorch >= 0.4.
    return loss.item()
# --- scraped page metadata (translated from Chinese, commented out so the
# --- file remains valid Python) ---
# Python usage examples of binary_cross_entropy()
# Source file: sigmoid_with_binary_cross_entropy.py
# Project: pytorch-misc
# Author: Jiaming-Liu
# (page counters: 24 reads, 0 favorites, 0 likes, 0 comments)
def forward(self, input, target):
    """Apply sigmoid to `input` in place, then return its BCE against `target`.

    The (now sigmoid-ed) input and the target are saved for the custom
    backward pass; the loss itself is computed on grad-free Variables so
    no autograd graph is built here, and the raw tensor (`.data`) is
    returned.
    """
    input.sigmoid_()
    self.save_for_backward(input, target)
    probs = torch.autograd.Variable(input, requires_grad=False)
    labels = torch.autograd.Variable(target, requires_grad=False)
    loss = F.binary_cross_entropy(
        probs,
        labels,
        weight=self.weight,
        size_average=self.size_average,
    )
    return loss.data
def loss_function(recon_x, x, mu, logvar):
    """VAE objective: reconstruction BCE plus normalised KL divergence.

    Analytical KL for a diagonal-Gaussian posterior, Appendix B of
    Kingma and Welling, "Auto-Encoding Variational Bayes", ICLR 2014
    (https://arxiv.org/abs/1312.6114):
        KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    """
    reconstruction = F.binary_cross_entropy(recon_x, x.view(-1, 784))
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalise by the same number of elements as in the reconstruction
    kl_term /= args.batch_size * 784
    return reconstruction + kl_term
def _backward(self):
# TODO: we need to have a custom loss function to take mask into account
# TODO: pass in this way might be too unefficient, but it's ok for now
if self.training:
self.optimizer.zero_grad()
loss_vb = F.binary_cross_entropy(input=self.output_vb.transpose(0, 1).contiguous().view(1, -1),
target=self.target_vb.transpose(0, 1).contiguous().view(1, -1),
weight=self.mask_ts.transpose(0, 1).contiguous().view(1, -1))
loss_vb /= self.batch_size
if self.training:
loss_vb.backward()
self.optimizer.step()
return loss_vb.data[0]
def bceloss_no_reduce_test():
    """Build the test spec for unreduced F.binary_cross_entropy (reduce=False)."""
    target = torch.randn(15, 10).gt(0).double()
    return dict(
        fullname='BCELoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(
                i, Variable(target.type_as(i.data)), reduce=False)),
        # probabilities clamped away from 0/1 so log() stays finite
        input_fn=lambda: torch.rand(15, 10).clamp_(2e-2, 1 - 2e-2),
        # elementwise reference: -(t*log(p) + (1-t)*log(1-p))
        reference_fn=lambda i, m: -(target * i.log()
                                    + (1 - target) * (1 - i).log()),
        check_gradgrad=False,
        pickle=False)
def bceloss_weights_no_reduce_test():
    """Build the test spec for unreduced, per-class-weighted BCE (reduce=False)."""
    target = torch.randn(15, 10).gt(0).double()
    class_weights = torch.rand(10)
    return dict(
        fullname='BCELoss_weights_no_reduce',
        constructor=wrap_functional(
            lambda i: F.binary_cross_entropy(
                i, Variable(target.type_as(i.data)),
                weight=class_weights.type_as(i.data), reduce=False)),
        # probabilities clamped away from 0/1 so log() stays finite
        input_fn=lambda: torch.rand(15, 10).clamp_(2e-2, 1 - 2e-2),
        # weighted elementwise reference
        reference_fn=lambda i, m: -(target * i.log()
                                    + (1 - target) * (1 - i).log()) * class_weights,
        check_gradgrad=False,
        pickle=False)
def ptc_dis_step(self):
    """
    Train the patch discriminator for one batch.

    Real images are pushed towards score 1 - smooth_label and decoded
    fakes towards smooth_label (one-sided label smoothing).
    """
    data = self.data
    params = self.params
    self.ae.eval()
    self.ptc_dis.train()
    bs = params.batch_size
    # batch / encode / discriminate
    batch_x, batch_y = data.train_batch(bs)
    flipped = flip_attributes(batch_y, params, 'all')
    # volatile=True is the pre-0.4 "no grad through the autoencoder" mechanism
    _, dec_outputs = self.ae(Variable(batch_x.data, volatile=True), flipped)
    real_preds = self.ptc_dis(batch_x)
    # re-wrap .data so no gradient flows back into the decoder
    fake_preds = self.ptc_dis(Variable(dec_outputs[-1].data))
    y_fake = Variable(torch.FloatTensor(real_preds.size())
                      .fill_(params.smooth_label).cuda())
    # loss / optimize
    loss = F.binary_cross_entropy(real_preds, 1 - y_fake)
    loss += F.binary_cross_entropy(fake_preds, y_fake)
    # .item() replaces the pre-0.4 `loss.data[0]`, which raises an
    # IndexError on the 0-dim loss tensors of PyTorch >= 0.4.
    self.stats['ptc_dis_costs'].append(loss.item())
    self.ptc_dis_optimizer.zero_grad()
    loss.backward()
    if params.clip_grad_norm:
        clip_grad_norm(self.ptc_dis.parameters(), params.clip_grad_norm)
    self.ptc_dis_optimizer.step()
def binaryXloss(logits, label):
    """Binary cross-entropy restricted to non-void pixels.

    Pixels whose label equals VOID_LABEL are excluded from the loss.
    If every pixel is void, a zero loss that is still connected to the
    graph is returned so gradients are exactly 0.
    """
    valid = (label.view(-1) != VOID_LABEL)
    if valid.long().sum() == 0:
        # only void pixels: keep the graph alive but contribute nothing
        return logits.sum() * 0.
    masked_target = label.contiguous().view(-1)[valid]
    masked_logits = logits.contiguous().view(-1)[valid]
    # StableBCELoss works on raw logits (numerically safer than
    # F.binary_cross_entropy on probabilities)
    return StableBCELoss()(masked_logits, Variable(masked_target.float()))
def _train(args, T, model, shared_model, shared_average_model, optimiser, policies, Qs, Vs, actions, rewards, Qret, average_policies, target_class, pred_class, old_policies=None):
    """One ACER-style update: policy, value and auxiliary classification losses.

    Walks the trajectory backwards, computing truncated
    importance-sampled returns (Qret, as in Retrace), then hands the
    summed losses to `_update_networks`. A non-None `old_policies`
    switches to the off-policy (replayed) variant with bias correction.
    """
    off_policy = old_policies is not None
    policy_loss, value_loss, class_loss = 0, 0, 0
    # Calculate n-step returns in forward view, stepping backwards from the last state
    t = len(rewards)
    for i in reversed(range(t)):
        # Importance sampling weights ρ ← π(∙|s_i) / µ(∙|s_i); 1 for on-policy
        # NOTE(review): `x and a or b` evaluates the truthiness of `a`; this
        # relies on the ratio tensor never being treated as falsy — confirm.
        rho = off_policy and policies[i].detach() / old_policies[i] or Variable(torch.ones(1, ACTION_SIZE))
        # Qret ← r_i + γ·Qret
        Qret = rewards[i] + args.discount * Qret
        # Advantage A ← Qret - V(s_i; θ)
        A = Qret - Vs[i]
        # Log policy log(π(a_i|s_i; θ))
        log_prob = policies[i].gather(1, actions[i]).log()
        # g ← min(c, ρ_a_i)·∇θ·log(π(a_i|s_i; θ))·A
        single_step_policy_loss = -(rho.gather(1, actions[i]).clamp(max=args.trace_max) * log_prob * A).mean(0)  # Average over batch
        # Off-policy bias correction
        if off_policy:
            # g ← g + Σ_a [1 - c/ρ_a]₊ · π(a|s_i; θ) · ∇θ·log(π(a|s_i; θ)) · (Q(s_i, a; θ) - V(s_i; θ))
            bias_weight = (1 - args.trace_max / rho).clamp(min=0) * policies[i]
            single_step_policy_loss -= (bias_weight * policies[i].log() * (Qs[i].detach() - Vs[i].expand_as(Qs[i]).detach())).sum(1).mean(0)
        if args.trust_region:
            # Policy update dθ ← dθ + ∂θ·z* (KL-constrained trust-region step)
            policy_loss += _trust_region_loss(model, policies[i], average_policies[i], single_step_policy_loss, args.trust_region_threshold)
        else:
            # Policy update dθ ← dθ + ∂θ·g
            policy_loss += single_step_policy_loss
        # Entropy regularisation dθ ← dθ - β·∇θ·H(π(s_i; θ))
        policy_loss += args.entropy_weight * -(policies[i].log() * policies[i]).sum(1).mean(0)
        # Value update dθ ← dθ - ∇θ·½·(Qret - Q(s_i, a_i; θ))²
        Q = Qs[i].gather(1, actions[i])
        value_loss += ((Qret - Q) ** 2 / 2).mean(0)  # Least squares loss
        # Truncated importance weight ρ̄_a_i = min(1, ρ_a_i)
        truncated_rho = rho.gather(1, actions[i]).clamp(max=1)
        # Qret ← ρ̄_a_i·(Qret - Q(s_i, a_i; θ)) + V(s_i; θ)
        Qret = truncated_rho * (Qret - Q.detach()) + Vs[i].detach()
        # Auxiliary classification loss against the per-step prediction
        class_loss += F.binary_cross_entropy(pred_class[i], target_class)
    # Optionally normalise loss by number of time steps
    if not args.no_time_normalisation:
        policy_loss /= t
        value_loss /= t
        class_loss /= t
    # Update networks
    _update_networks(args, T, model, shared_model, shared_average_model, policy_loss + value_loss + class_loss, optimiser)
# Acts and trains model
def train(epoch):
    """Run one pix2pix-style training epoch over `training_data_loader`.

    Alternates: (1) train discriminator netD on the real pair vs the
    detached generator output, (2) train generator netG with an
    adversarial loss plus 100x smooth-L1 reconstruction loss.
    Logs running losses every 100 batches.
    """
    for batch, (left, right) in enumerate(training_data_loader):
        # Copy the batch into the persistent input/target Variables,
        # choosing the translation direction (lr: left->right).
        if args.direction == 'lr':
            input.data.resize_(left.size()).copy_(left)
            target.data.resize_(right.size()).copy_(right)
        else:
            input.data.resize_(right.size()).copy_(right)
            target.data.resize_(left.size()).copy_(left)
        ## Discriminator
        netD.zero_grad()
        # real pair
        D_real = netD(input, target)
        ones_label.data.resize_(D_real.size()).fill_(1)
        zeros_label.data.resize_(D_real.size()).fill_(0)
        D_loss_real = F.binary_cross_entropy(D_real, ones_label)
        D_x_y = D_real.data.mean()
        # fake pair (detach so D's step sends no gradient into G)
        G_fake = netG(input)
        D_fake = netD(input, G_fake.detach())
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
        D_x_gx = D_fake.data.mean()
        D_loss = D_loss_real + D_loss_fake
        D_loss.backward()
        D_solver.step()
        ## Generator
        netG.zero_grad()
        G_fake = netG(input)
        D_fake = netD(input, G_fake)
        D_x_gx_2 = D_fake.data.mean()
        G_loss = F.binary_cross_entropy(D_fake, ones_label) + 100 * F.smooth_l1_loss(G_fake, target)
        G_loss.backward()
        G_solver.step()
        ## debug
        if (batch + 1) % 100 == 0:
            # .item() replaces the pre-0.4 `.data[0]`, which raises an
            # IndexError on the 0-dim loss tensors of PyTorch >= 0.4.
            print('[TRAIN] Epoch[{}]({}/{}); D_loss: {:.4f}; G_loss: {:.4f}; D(x): {:.4f} D(G(z)): {:.4f}/{:.4f}'.format(
                epoch, batch + 1, len(training_data_loader), D_loss.item(), G_loss.item(), D_x_y, D_x_gx, D_x_gx_2))
def autoencoder_step(self):
    """
    Train the autoencoder with cross-entropy loss.
    Train the encoder with discriminator loss.
    """
    data = self.data
    params = self.params
    self.ae.train()
    # discriminators are only queried, never trained, in this step
    if params.n_lat_dis:
        self.lat_dis.eval()
    if params.n_ptc_dis:
        self.ptc_dis.eval()
    if params.n_clf_dis:
        self.clf_dis.eval()
    bs = params.batch_size
    # batch / encode / decode
    batch_x, batch_y = data.train_batch(bs)
    enc_outputs, dec_outputs = self.ae(batch_x, batch_y)
    # autoencoder loss from reconstruction (MSE)
    loss = params.lambda_ae * ((batch_x - dec_outputs[-1]) ** 2).mean()
    # .item() replaces the pre-0.4 `loss.data[0]`, which raises an
    # IndexError on the 0-dim loss tensors of PyTorch >= 0.4.
    self.stats['rec_costs'].append(loss.item())
    # encoder loss from the latent discriminator
    if params.lambda_lat_dis:
        lat_dis_preds = self.lat_dis(enc_outputs[-1 - params.n_skip])
        lat_dis_loss = get_attr_loss(lat_dis_preds, batch_y, True, params)
        loss = loss + get_lambda(params.lambda_lat_dis, params) * lat_dis_loss
    # decoding with random labels
    if params.lambda_ptc_dis + params.lambda_clf_dis > 0:
        flipped = flip_attributes(batch_y, params, 'all')
        dec_outputs_flipped = self.ae.decode(enc_outputs, flipped)
    # autoencoder loss from the patch discriminator
    if params.lambda_ptc_dis:
        ptc_dis_preds = self.ptc_dis(dec_outputs_flipped[-1])
        y_fake = Variable(torch.FloatTensor(ptc_dis_preds.size())
                          .fill_(params.smooth_label).cuda())
        # the autoencoder wants flipped decodings to look real (1 - smooth_label)
        ptc_dis_loss = F.binary_cross_entropy(ptc_dis_preds, 1 - y_fake)
        loss = loss + get_lambda(params.lambda_ptc_dis, params) * ptc_dis_loss
    # autoencoder loss from the classifier discriminator
    if params.lambda_clf_dis:
        clf_dis_preds = self.clf_dis(dec_outputs_flipped[-1])
        clf_dis_loss = get_attr_loss(clf_dis_preds, flipped, False, params)
        loss = loss + get_lambda(params.lambda_clf_dis, params) * clf_dis_loss
    # check NaN (x != x is True only for NaN)
    if (loss != loss).data.any():
        logger.error("NaN detected")
        exit()
    # optimize
    self.ae_optimizer.zero_grad()
    loss.backward()
    if params.clip_grad_norm:
        clip_grad_norm(self.ae.parameters(), params.clip_grad_norm)
    self.ae_optimizer.step()