def train_epoch(self, epoch):
    self.model.train()
    total_loss = 0
    for batch_idx, batch in enumerate(self.train_loader):
        self.optimizer.zero_grad()
        output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats)
        loss = F.kl_div(output, batch.label)
        total_loss += loss.data[0]
        loss.backward()
        self.optimizer.step()
        if batch_idx % self.log_interval == 0:
            self.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, min(batch_idx * self.batch_size, len(batch.dataset.examples)),
                len(batch.dataset.examples),
                100. * batch_idx / len(self.train_loader), loss.data[0])
            )
    if self.use_tensorboard:
        self.writer.add_scalar('sick/train/kl_div_loss', total_loss, epoch)
    return total_loss
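F.kl_div expects log-probabilities as its first argument and plain probabilities as its target, so the model above presumably ends in a log_softmax (its output is exponentiated later with output.data.exp()). A minimal sketch of that contract, with made-up tensors standing in for the model output and batch.label:

import torch
import torch.nn.functional as F

# Hypothetical tensors, just to show the input/target contract of F.kl_div:
# the input must be log-probabilities, the target plain probabilities.
log_probs = F.log_softmax(torch.randn(4, 5), dim=1)  # stands in for the model output
targets = F.softmax(torch.randn(4, 5), dim=1)        # stands in for batch.label
loss = F.kl_div(log_probs, targets)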
def train_epoch(self, epoch):
    self.model.train()
    total_loss = 0
    # MSRVID has no validation set, so we manually leave out the last 20% of
    # training batches for validation and model selection
    batches = math.ceil(len(self.train_loader.dataset.examples) / self.batch_size)
    start_val_batch = math.floor(0.8 * batches)
    left_out_val_a, left_out_val_b = [], []
    left_out_val_ext_feats = []
    left_out_val_labels = []
    for batch_idx, batch in enumerate(self.train_loader):
        if batch_idx >= start_val_batch:
            left_out_val_a.append(batch.sentence_1)
            left_out_val_b.append(batch.sentence_2)
            left_out_val_ext_feats.append(batch.ext_feats)
            left_out_val_labels.append(batch.label)
            continue
        self.optimizer.zero_grad()
        output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats)
        loss = F.kl_div(output, batch.label)
        total_loss += loss.data[0]
        loss.backward()
        self.optimizer.step()
        if batch_idx % self.log_interval == 0:
            self.logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, min(batch_idx * self.batch_size, len(batch.dataset.examples)),
                len(batch.dataset.examples),
                100. * batch_idx / len(self.train_loader), loss.data[0])
            )
    self.evaluate(self.train_evaluator, 'train')
    if self.use_tensorboard:
        self.writer.add_scalar('msrvid/train/kl_div_loss', total_loss, epoch)
    return left_out_val_a, left_out_val_b, left_out_val_ext_feats, left_out_val_labels
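The leave-out split is plain arithmetic over batch indices. A worked example with assumed numbers (1000 examples, batch size 64) shows which batches end up held out:

import math

# Assumed numbers, for illustration only.
num_examples, batch_size = 1000, 64
batches = math.ceil(num_examples / batch_size)   # 16 batches
start_val_batch = math.floor(0.8 * batches)      # 12
# batches 0..11 are trained on; batches 12..15 are collected for validation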
def get_scores(self):
    self.model.eval()
    num_classes = self.dataset_cls.NUM_CLASSES
    predict_classes = torch.arange(1, num_classes + 1).expand(self.batch_size, num_classes)
    test_kl_div_loss = 0
    predictions = []
    true_labels = []
    for batch in self.data_loader:
        output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats)
        test_kl_div_loss += F.kl_div(output, batch.label, size_average=False).data[0]
        # handle the last batch, which might be smaller
        if len(predict_classes) != len(batch.sentence_1):
            predict_classes = torch.arange(1, num_classes + 1).expand(len(batch.sentence_1), num_classes)
        if self.data_loader.device != -1:
            with torch.cuda.device(self.device):
                predict_classes = predict_classes.cuda()
        true_labels.append((predict_classes * batch.label.data).sum(dim=1))
        predictions.append((predict_classes * output.data.exp()).sum(dim=1))
        del output
    predictions = torch.cat(predictions).cpu().numpy()
    true_labels = torch.cat(true_labels).cpu().numpy()
    test_kl_div_loss /= len(batch.dataset.examples)
    pearson_r = pearsonr(predictions, true_labels)[0]
    spearman_r = spearmanr(predictions, true_labels)[0]
    return [pearson_r, spearman_r, test_kl_div_loss], ['pearson_r', 'spearman_r', 'KL-divergence loss']
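The (predict_classes * output.data.exp()).sum(dim=1) lines turn a predicted distribution over similarity classes into a scalar expected score. A tiny sketch with a hypothetical 3-class distribution:

import torch

probs = torch.tensor([[0.1, 0.7, 0.2]])                # hypothetical exp(log-probs)
classes = torch.arange(1, 4).float().expand_as(probs)  # class values 1..3
expected_score = (classes * probs).sum(dim=1)          # tensor([2.1000])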
def get_scores(self):
    self.model.eval()
    num_classes = self.dataset_cls.NUM_CLASSES
    predict_classes = torch.arange(0, num_classes).expand(self.batch_size, num_classes)
    test_kl_div_loss = 0
    predictions = []
    true_labels = []
    for batch in self.data_loader:
        output = self.model(batch.sentence_1, batch.sentence_2, batch.ext_feats)
        test_kl_div_loss += F.kl_div(output, batch.label, size_average=False).data[0]
        # handle the last batch, which might be smaller
        if len(predict_classes) != len(batch.sentence_1):
            predict_classes = torch.arange(0, num_classes).expand(len(batch.sentence_1), num_classes)
        if self.data_loader.device != -1:
            with torch.cuda.device(self.device):
                predict_classes = predict_classes.cuda()
        true_labels.append((predict_classes * batch.label.data).sum(dim=1))
        predictions.append((predict_classes * output.data.exp()).sum(dim=1))
        del output
    predictions = torch.cat(predictions).cpu().numpy()
    true_labels = torch.cat(true_labels).cpu().numpy()
    test_kl_div_loss /= len(batch.dataset.examples)
    pearson_r = pearsonr(predictions, true_labels)[0]
    return [pearson_r, test_kl_div_loss], ['pearson_r', 'KL-divergence loss']
def softmax_kl_loss(input_logits, target_logits):
    """Takes softmax on both sides and returns KL divergence

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to inputs but not the targets.
    """
    assert input_logits.size() == target_logits.size()
    input_log_softmax = F.log_softmax(input_logits, dim=1)
    target_softmax = F.softmax(target_logits, dim=1)
    return F.kl_div(input_log_softmax, target_softmax, size_average=False)
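A hedged usage sketch of softmax_kl_loss, e.g. as a consistency loss between student and teacher logits (the tensor names are illustrative); as the docstring says, the summed loss is divided by the batch size to get a mean:

import torch

student_logits = torch.randn(8, 10, requires_grad=True)  # illustrative shapes
teacher_logits = torch.randn(8, 10)                       # no grad: targets stay fixed
loss = softmax_kl_loss(student_logits, teacher_logits) / student_logits.size(0)
loss.backward()  # gradients flow to student_logits only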
def _trust_region_loss(model, distribution, ref_distribution, loss, threshold):
    # Compute gradients from original loss
    model.zero_grad()
    loss.backward(retain_graph=True)
    # Gradients should be treated as constants (not using detach as volatility can creep in when double backprop is not implemented)
    g = [Variable(param.grad.data.clone()) for param in model.parameters() if param.grad is not None]
    model.zero_grad()
    # KL divergence k = ∇θ DKL[π(·|s_i; θ_a) || π(·|s_i; θ)]
    kl = F.kl_div(distribution.log(), ref_distribution, size_average=False)
    # Compute gradients from (negative) KL loss (increases KL divergence)
    (-kl).backward(retain_graph=True)
    k = [Variable(param.grad.data.clone()) for param in model.parameters() if param.grad is not None]
    model.zero_grad()
    # Compute dot products of gradients
    k_dot_g = sum(torch.sum(k_p * g_p) for k_p, g_p in zip(k, g))
    k_dot_k = sum(torch.sum(k_p ** 2) for k_p in k)
    # Compute trust region update
    if k_dot_k.data[0] > 0:
        trust_factor = ((k_dot_g - threshold) / k_dot_k).clamp(min=0)
    else:
        trust_factor = Variable(torch.zeros(1))
    # z* = g - max(0, (k^T·g - δ) / ||k||^2_2)·k
    z_star = [g_p - trust_factor.expand_as(k_p) * k_p for g_p, k_p in zip(g, k)]
    trust_loss = 0
    for param, z_star_p in zip(model.parameters(), z_star):
        trust_loss += (param * z_star_p).sum()
    return trust_loss
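The update above implements z* = g - max(0, (k·g - δ) / ||k||^2)·k. A tiny numeric sketch with made-up two-dimensional gradients:

import torch

g = torch.tensor([1.0, 2.0])   # loss gradient (made up)
k = torch.tensor([0.5, 0.5])   # KL gradient (made up)
delta = 1.0                    # trust-region threshold
scale = (((k * g).sum() - delta) / (k * k).sum()).clamp(min=0)
z_star = g - scale * k         # tensor([0.5000, 1.5000])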
def _1st_order_trpo(self, detached_policy_loss_vb, detached_policy_vb, detached_avg_policy_vb, detached_splitted_policy_vb=None):
    on_policy = detached_splitted_policy_vb is None
    # KL divergence k = \nabla_{\phi_\theta} DKL[ \pi(\cdot|\phi_{\theta_a}) || \pi(\cdot|\phi_\theta) ]
    # kl_div_vb = F.kl_div(detached_policy_vb.log(), detached_avg_policy_vb, size_average=False) # NOTE: the built-in one does not work on batch
    kl_div_vb = categorical_kl_div(detached_policy_vb, detached_avg_policy_vb)
    # NOTE: k & g are all w.r.t. the network output, which is detached_policy_vb
    # NOTE: gradient from this part will not flow back into the model
    # NOTE: that's why we are only using detached policy variables here
    if on_policy:
        k_vb = grad(outputs=kl_div_vb, inputs=detached_policy_vb, retain_graph=False, only_inputs=True)[0]
        g_vb = grad(outputs=detached_policy_loss_vb, inputs=detached_policy_vb, retain_graph=False, only_inputs=True)[0]
    else:
        # NOTE: here is why we cannot simply detach then split the policy_vb, but must split before detaching:
        # NOTE: because if we did that, the split tensors could not backtrace the grads computed in this later part of the graph,
        # NOTE: and would have no way to connect to the graphs in the model
        k_vb = grad(outputs=(kl_div_vb.split(1, 0)), inputs=(detached_splitted_policy_vb), retain_graph=False, only_inputs=True)
        g_vb = grad(outputs=(detached_policy_loss_vb.split(1, 0)), inputs=(detached_splitted_policy_vb), retain_graph=False, only_inputs=True)
        k_vb = torch.cat(k_vb, 0)
        g_vb = torch.cat(g_vb, 0)
    kg_dot_vb = (k_vb * g_vb).sum(1, keepdim=True)
    kk_dot_vb = (k_vb * k_vb).sum(1, keepdim=True)
    z_star_vb = g_vb - ((kg_dot_vb - self.master.clip_1st_order_trpo) / kk_dot_vb).clamp(min=0) * k_vb
    return z_star_vb
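The commented-out line notes that the built-in F.kl_div reduces over the whole batch, which is why a custom categorical_kl_div is used. That helper is not shown in this snippet; a plausible per-row sketch of what it presumably computes:

import torch

def categorical_kl_div_sketch(policy, avg_policy, eps=1e-8):
    # DKL[avg_policy || policy] per batch row, summed over actions
    # (an assumed stand-in for the real categorical_kl_div, which is not shown here)
    return (avg_policy * (torch.log(avg_policy + eps) - torch.log(policy + eps))).sum(1)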
def distillation(y, teacher_scores, labels, T, alpha):
    return F.kl_div(F.log_softmax(y / T), F.softmax(teacher_scores / T)) * (T * T * 2. * alpha) \
        + F.cross_entropy(y, labels) * (1. - alpha)
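A hedged call sketch for this Hinton-style distillation loss, with illustrative shapes (16 examples, 10 classes); T softens both distributions and alpha trades the soft-target KL term off against the hard-label cross-entropy:

import torch

y = torch.randn(16, 10, requires_grad=True)  # student logits (illustrative)
teacher_scores = torch.randn(16, 10)         # teacher logits (illustrative)
labels = torch.randint(0, 10, (16,))
loss = distillation(y, teacher_scores, labels, T=4.0, alpha=0.9)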
def kldivloss_no_reduce_test():
    t = Variable(torch.randn(10, 10))
    return dict(
        fullname='KLDivLoss_no_reduce',
        constructor=wrap_functional(
            lambda i: F.kl_div(i, t.type_as(i), reduce=False)),
        input_fn=lambda: torch.rand(10, 10).log(),
        reference_fn=lambda i, _:
            loss_reference_fns['KLDivLoss'](i, t.data.type_as(i), reduce=False),
        pickle=False)
def rocket_distillation(y, teacher_scores, labels, T, alpha):
    return F.kl_div(F.log_softmax(y / T), F.softmax(teacher_scores / T)) * (T * T * 2. * alpha)
def forward(self, fc, seq_len):
    '''
    fc: [bsz, max_len, fc_size], has already passed through a sigmoid layer
    seq_len: [bsz]
    '''
    loss = Variable(torch.zeros(1), requires_grad=True)
    if fc.is_cuda:
        self.age_dis_trans.cuda()
        loss = loss.cuda()
    bsz, max_len = fc.size()[0:2]
    fc = fc.view(bsz * max_len, -1)
    log_prob = F.log_softmax(self.age_dis_trans(fc)).view(bsz, max_len, -1)
    prob = log_prob.detach().exp()
    seq_len = seq_len.long()
    for i in range(bsz):
        l = seq_len.data[i] - 1
        loss = loss + F.kl_div(log_prob[i, 0:l], prob[i, 1:(l + 1)], False) / l
    loss = loss / bsz
    return loss
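Here F.kl_div(..., False) passes size_average=False positionally, so each sequence's KL terms are summed and then normalized by the sequence length and batch size; the loss pushes each timestep's predicted distribution toward the (detached) distribution at the next timestep, i.e. a temporal smoothness regularizer. A standalone sketch of one sequence's term, with made-up shapes:

import torch
import torch.nn.functional as F

# One hypothetical sequence of 7 steps over 5 classes.
log_prob = F.log_softmax(torch.randn(7, 5), dim=1)
prob = log_prob.detach().exp()
l = 6  # sequence length minus one
term = F.kl_div(log_prob[0:l], prob[1:l + 1], False) / l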