encdec.py source code

python

Project: TextGAN    Author: ankitkv
def mle_loss(self, outputs, targets):
    '''Maximum likelihood estimation loss.'''
    # Written against pre-1.0 TensorFlow: tf.concat(axis, values), tf.nn.seq2seq,
    # and sampled_softmax_loss taking inputs before labels.
    # Mask out padding positions (target index 0).
    present_mask = tf.greater(targets, 0, name='present_mask')
    # Don't enforce the loss on true <unk> targets.
    unk_mask = tf.not_equal(targets, self.vocab.unk_index, name='unk_mask')
    mask = tf.cast(tf.logical_and(present_mask, unk_mask), tf.float32)
    # Join the per-timestep outputs and flatten to [batch * time, hidden_size].
    output = tf.reshape(tf.concat(1, outputs), [-1, cfg.hidden_size])
    if self.training and cfg.softmax_samples < len(self.vocab.vocab):
        # Training with a large vocabulary: approximate the softmax by sampling.
        targets = tf.reshape(targets, [-1, 1])
        mask = tf.reshape(mask, [-1])
        loss = tf.nn.sampled_softmax_loss(self.softmax_w, self.softmax_b, output, targets,
                                          cfg.softmax_samples, len(self.vocab.vocab))
        loss *= mask
    else:
        # Otherwise compute the full softmax over the vocabulary.
        logits = tf.nn.bias_add(tf.matmul(output, tf.transpose(self.softmax_w),
                                          name='softmax_transform_mle'), self.softmax_b)
        loss = tf.nn.seq2seq.sequence_loss_by_example([logits],
                                                      [tf.reshape(targets, [-1])],
                                                      [tf.reshape(mask, [-1])])
    # Per-token losses, reshaped back to [batch_size, time].
    return tf.reshape(loss, [cfg.batch_size, -1])
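The full-softmax branch computes the cross-entropy for each token and scales it by the 0/1 mask. It passes a single flattened "timestep" to `sequence_loss_by_example`, so with the default averaging each row effectively reduces to cross-entropy times mask.

The pure-Python sketch below mirrors that computation so the masking is easy to check by hand. It makes four assumptions:

* logits and targets are already flattened in batch-major order, as `output` and `tf.reshape(targets, [-1])` are in the method above;
* index 0 is padding, as implied by `tf.greater(targets, 0)`;
* `masked_token_nll` and `to_batch_major` are hypothetical helper names, not part of TextGAN;
* it does not reproduce the sampled-softmax training branch, which only approximates this loss.

```python
import math
from typing import List, Sequence


def masked_token_nll(logits: Sequence[Sequence[float]],
                     targets: Sequence[int],
                     unk_index: int) -> List[float]:
    """Per-token negative log-likelihood under a full softmax.

    Hypothetical helper: targets equal to 0 (assumed padding) or to unk_index
    get a loss of 0.0, like multiplying by the mask in mle_loss.
    """
    losses = []
    for row, target in zip(logits, targets):
        keep = target > 0 and target != unk_index  # present_mask AND unk_mask
        if not keep:
            losses.append(0.0)
            continue
        # Numerically stable log-sum-exp over the vocabulary.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_z - row[target])  # -log softmax(row)[target]
    return losses


def to_batch_major(flat: Sequence[float], batch_size: int) -> List[List[float]]:
    """Reshape a flat batch-major list to [batch_size, time] (like tf.reshape)."""
    time = len(flat) // batch_size
    return [list(flat[i * time:(i + 1) * time]) for i in range(batch_size)]
```

For example, with a 4-word vocabulary, `unk_index=1`, all-zero logits and targets `[2, 3, 1, 0]` for a batch of 2, the first sequence gets `log(4)` per token. The second sequence gets `[0.0, 0.0]` because its targets are `<unk>` and padding.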