model.py source code

Project: SentenceOrdering_PTR    Author: JerrikEph
import tensorflow as tf  # TensorFlow 1.x API; add_loss_op is a method of the model class

def add_loss_op(self, logits):
        def seq_loss(logits_tensor, label_tensor, length_tensor):
            """
            Args
                logits_tensor: shape (batch_size*time_steps_de, time_steps_en)
                label_tensor: shape (batch_size, time_steps_de), 2D tensor of label ids
                length_tensor: shape (batch_size,)
            Return
                loss: a scalar tensor, the mean per-sequence cross-entropy
            """
            labels = tf.reshape(label_tensor, shape=(-1,))
            # Per-token cross-entropy; labels and logits must be passed as keyword arguments.
            loss_flat = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels, logits=logits_tensor, name='sparse_softmax')
            losses = tf.reshape(loss_flat, shape=tf.shape(label_tensor))  # (batch_size, tstp_de)
            # Zero out the loss on padded time steps beyond each sequence's true length.
            length_mask = tf.sequence_mask(length_tensor, tf.shape(losses)[1],
                                           dtype=tf.float32, name='length_mask')
            losses_sum = tf.reduce_sum(losses * length_mask, axis=[1])  # (batch_size,)
            # Average over valid steps only; the epsilon guards against zero-length sequences.
            losses_mean = losses_sum / (tf.to_float(length_tensor) + 1e-20)  # (batch_size,)
            loss = tf.reduce_mean(losses_mean)  # scalar
            return loss

        # L2 regularization over all trainable variables except the embedding matrix.
        reg_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()
                             if v is not self.embedding]) * self.config.reg
        valid_loss = seq_loss(logits, self.decoder_label, self.decoder_tstps)
        train_loss = reg_loss + valid_loss
        return train_loss, valid_loss, reg_loss
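The masking arithmetic inside seq_loss can be checked in isolation. The sketch below (plain NumPy, hypothetical values, not the project's code) reproduces the same steps: build a 0/1 mask from sequence lengths, sum losses over valid steps, and average per sequence.

```python
import numpy as np

# Hypothetical per-token losses for a batch of 2 sequences, max 3 decode steps.
losses = np.array([[0.5, 1.0, 2.0],
                   [0.2, 0.4, 0.0]])
lengths = np.array([3, 2])

# Equivalent of tf.sequence_mask: 1.0 on valid steps, 0.0 on padding.
mask = (np.arange(losses.shape[1])[None, :] < lengths[:, None]).astype(np.float32)

# Sum valid losses per sequence, divide by its length (epsilon guards /0),
# then average over the batch -- mirroring losses_sum, losses_mean, and loss above.
per_seq = (losses * mask).sum(axis=1) / (lengths.astype(np.float32) + 1e-20)
loss = per_seq.mean()
```

Here per_seq is [3.5/3, 0.6/2]; dividing by the true lengths rather than the padded width keeps short sequences from being under-weighted.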