def update_diff(self, accuracy, batch_idxs, batch_losses, batch_plens, loss_w=0.5, smooth_w=0.5):
    """Build the op that refreshes ``self.seq_entropy`` for one batch.

    The entries of ``self.seq_entropy`` selected by ``batch_idxs`` are
    replaced by a blend of their current value and a fresh value derived
    from the batch losses and ``batch_plens``. Assigning ``accuracy`` to
    ``self.acc_coef`` is attached as a control dependency of the ops
    created here.

    Args:
        accuracy: Value assigned to ``self.acc_coef``.
        batch_idxs: Indices into ``self.seq_entropy`` for this batch.
        batch_losses: Losses for the batch; divided by their maximum.
        batch_plens: Values divided by ``self.max_plen`` (presumably
            per-sequence lengths -- confirm against the caller).
        loss_w: Weight of the loss term; the ``batch_plens`` term gets
            ``1 - loss_w``.
        smooth_w: Weight of the currently stored value; the fresh value
            gets ``1 - smooth_w``.

    Returns:
        The ``tf.scatter_update`` op writing the blended values into
        ``self.seq_entropy`` at ``batch_idxs``.
    """
    # NOTE(review): every op below is assumed to belong inside the
    # control-dependency scope -- confirm against the upstream source.
    store_accuracy = tf.assign(self.acc_coef, accuracy)
    with tf.control_dependencies([store_accuracy]):
        previous = tf.gather(self.seq_entropy, batch_idxs)

        # The 1e-8 keeps the divisor non-zero when the batch maximum is 0.
        loss_scale = tf.reduce_max(batch_losses) + 1e-8
        relative_loss = batch_losses / loss_scale

        # Weighted mix of the loss term and the batch_plens term.
        loss_part = relative_loss * loss_w
        length_part = batch_plens / self.max_plen * (1 - loss_w)
        candidate = loss_part + length_part

        # Blend the stored value with the fresh one using smooth_w.
        kept = previous * smooth_w
        incoming = candidate * (1 - smooth_w)
        blended = kept + incoming

        return tf.scatter_update(self.seq_entropy, batch_idxs, blended)
# Web-page navigation residue, not part of the code:
# "评论列表" (comment list), "文章目录" (table of contents).