ImportanceEstimator.py source code

python

Project: LargeNumberClasses  Author: maodayezheng
import tensorflow as tf  # TensorFlow 1.x API (tf.log, tf.Assert)

def loss(self, x, h, q=None):
        """
            Calculate the estimated loss of the importance sampling approximation

            @Param x(NxD): The target word or batch
            @Param h(NxD): Usually the output of the neural network
            @Param q(N): The weight of each target (required despite the None default)
        """
        if q is None:
            raise ValueError("q must be set")
        # K
        weights = self.get_sample_weights()
        if weights is None:
            raise ValueError("sample weights must be set")
        # weights are divisors below, so every entry must be nonzero
        nonzero = tf.Assert(tf.reduce_all(tf.not_equal(weights, 0.0)), [weights])
        with tf.control_dependencies([nonzero]):
            weights = tf.identity(weights)
        # KxD
        samples = self.get_samples()
        if samples is None:
            raise ValueError("samples must be set")
        # N
        target_scores = tf.reduce_sum(x * h, 1)
        self.target_exp_ = tf.exp(target_scores)
        # N x K
        samples_scores = tf.matmul(h, samples, transpose_b=True)
        # N x K: each sample's exponentiated score divided by its weight
        exp_weight = tf.exp(samples_scores) / weights
        # N: estimate of the normalizer for each row of h
        self.Z_ = tf.reduce_sum(tf.check_numerics(exp_weight, "each Z "), 1)

        # The loss of each element in target
        # N
        element_loss = target_scores - tf.log(q) - tf.log(self.Z_)
        loss = tf.reduce_mean(element_loss)
        return -loss
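
The same computation can be checked by hand on small inputs without TensorFlow. The following is a minimal pure-Python sketch, assuming `x` and `h` are N rows of D floats, `samples` is K rows of D floats, and `weights` and `q` are nonzero, positive floats. The names `dot` and `importance_loss` are hypothetical helpers for illustration and are not part of the LargeNumberClasses project.

```python
import math


def dot(a, b):
    """Inner product of two equal-length vectors given as lists."""
    return sum(ai * bi for ai, bi in zip(a, b))


def importance_loss(x, h, samples, weights, q):
    """Plain-list mirror of the loss method (hypothetical helper).

    x, h: N rows of D floats; samples: K rows of D floats;
    weights: K positive floats; q: N positive floats.
    """
    element_losses = []
    for x_n, h_n, q_n in zip(x, h, q):
        # Score of the target for this row
        target_score = dot(x_n, h_n)
        # Z_n = sum over k of exp(h_n . s_k) / w_k
        z_n = sum(math.exp(dot(h_n, s_k)) / w_k
                  for s_k, w_k in zip(samples, weights))
        element_losses.append(target_score - math.log(q_n) - math.log(z_n))
    # Negative mean over the batch, matching the method's return value
    return -sum(element_losses) / len(element_losses)


# With h = 0 every score is 0, so Z = 1/0.5 + 1/0.5 = 4 and the loss is log(4).
value = importance_loss(
    x=[[1.0, 2.0]],
    h=[[0.0, 0.0]],
    samples=[[1.0, 0.0], [0.0, 1.0]],
    weights=[0.5, 0.5],
    q=[1.0],
)
```

Because the sketch exponentiates raw scores just as the method does, it will overflow for large scores; it is meant for verifying shapes and small values, not for training.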