def add_loss_op(self, logits):
    def seq_loss(logits_tensor, label_tensor, length_tensor):
        """Compute the mean length-masked cross-entropy loss over a batch.

        Args:
            logits_tensor: shape (batch_size * time_steps_de, time_steps_en)
            label_tensor: shape (batch_size, time_steps_de), 2D tensor of label ids
            length_tensor: shape (batch_size,), true decoder lengths

        Returns:
            loss: a scalar tensor, the mean per-sequence error
        """
        labels = tf.reshape(label_tensor, shape=(-1,))
        # Keyword arguments are required here; TF >= 1.0 rejects positional ones.
        loss_flat = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits_tensor, name='sparse_softmax')
        losses = tf.reshape(loss_flat, shape=tf.shape(label_tensor))  # (batch_size, tstp_de)
        # Zero out the loss at padded positions beyond each sequence's true length.
        length_mask = tf.sequence_mask(length_tensor, tf.shape(losses)[1],
                                       dtype=tf.float32, name='length_mask')
        losses_sum = tf.reduce_sum(losses * length_mask, axis=1)  # (batch_size,)
        # The small epsilon guards against division by zero for empty sequences.
        losses_mean = losses_sum / (tf.to_float(length_tensor) + 1e-20)  # (batch_size,)
        loss = tf.reduce_mean(losses_mean)  # scalar
        return loss

    # L2-regularize every trainable variable except the embedding matrix.
    reg_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()
                         if v is not self.embedding]) * self.config.reg
    valid_loss = seq_loss(logits, self.decoder_label, self.decoder_tstps)
    train_loss = reg_loss + valid_loss
    return train_loss, valid_loss, reg_loss
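The masking logic can be sanity-checked on its own with dummy tensors. The sketch below is not part of the original model: it inlines the same seq_loss computation as a standalone TF 1.x graph, and the shapes (batch of 2, 3 decoder steps, 4 encoder positions) and length values are illustrative assumptions.

import numpy as np
import tensorflow as tf  # TF 1.x

# Hypothetical shapes: batch_size=2, time_steps_de=3, time_steps_en=4.
logits = tf.constant(np.random.randn(2 * 3, 4), dtype=tf.float32)
labels = tf.constant([[0, 2, 1], [3, 1, 0]], dtype=tf.int32)  # (2, 3)
lengths = tf.constant([3, 1], dtype=tf.int32)  # second sequence has 2 padded steps

# Inline replica of seq_loss for a standalone check.
flat_labels = tf.reshape(labels, (-1,))
loss_flat = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=flat_labels, logits=logits)
losses = tf.reshape(loss_flat, tf.shape(labels))
mask = tf.sequence_mask(lengths, tf.shape(losses)[1], dtype=tf.float32)
per_seq = tf.reduce_sum(losses * mask, axis=1) / (tf.to_float(lengths) + 1e-20)
loss = tf.reduce_mean(per_seq)

with tf.Session() as sess:
    # Padded positions in the second sequence contribute nothing to the loss.
    print(sess.run(loss))

Dividing each sequence's summed loss by its own length before averaging means short and long sequences are weighted equally in the batch; summing first and dividing by the total token count would instead weight every token equally.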