CandidateSample.py source code


Project: TFCommon    Author: MU94W
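import tensorflow as tf


def sampled_softmax_loss(label, logit, projection, num_sampled):
    """
    Args:
        label:          integer class ids, any shape; flattened to [N, 1]
        logit:          hidden activations fed into the output projection
                        (not the projected logits), shape [..., dim]
        projection:     (W, b), W of shape [dim, num_classes], b of shape [num_classes]
        num_sampled:    number of negative classes sampled per batch
    Returns:
        scalar sampled-softmax loss averaged over the N label positions
    """
    # Flatten to the 2-D shapes tf.nn.sampled_softmax_loss expects (num_true = 1).
    local_label = tf.reshape(tf.cast(label, tf.int64), shape=(-1, 1))
    local_logit = tf.reshape(logit, shape=(-1, tf.compat.dimension_value(logit.shape[-1])))
    # The op takes weights as [num_classes, dim], so transpose W.
    local_Wt    = tf.transpose(projection[0], perm=(1, 0))
    local_b     = projection[1]
    # One loss value per row, shape [N].
    per_example = tf.nn.sampled_softmax_loss(weights=local_Wt, biases=local_b,
                                             labels=local_label,
                                             inputs=local_logit,
                                             num_sampled=num_sampled,
                                             num_classes=tf.compat.dimension_value(local_Wt.shape[0]))
    # Mean over the N positions (sum divided by the number of labels).
    loss = tf.divide(tf.reduce_sum(per_example), tf.cast(tf.size(local_label), dtype=tf.float32))
    return loss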
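For intuition about what `tf.nn.sampled_softmax_loss` computes for a single row, here is a minimal pure-Python sketch. It is not TensorFlow's implementation: the helpers `sampled_softmax_one` and `full_softmax_one` are hypothetical, and the sketch swaps in a uniform sampler that draws negatives without replacement from the non-target classes, whereas TensorFlow by default uses a log-uniform (Zipfian) candidate sampler and removes accidental hits. The idea it shows is the same: the softmax normalizer is computed over the true class plus a few sampled classes instead of all `num_classes`.

```python
import math
import random


def logsumexp(xs):
    # Numerically stable log(sum(exp(x) for x in xs)).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def sampled_softmax_one(h, W, b, true_class, num_sampled, rng):
    """Hypothetical helper: sampled-softmax loss for one hidden vector h.

    Assumes W is a list of num_classes weight rows (each of length dim) and
    b a list of num_classes biases. ASSUMPTION: negatives are drawn uniformly
    without replacement from the classes other than true_class.
    """
    negatives = [c for c in range(len(W)) if c != true_class]
    sampled = rng.sample(negatives, num_sampled)
    # Only the true class and the sampled classes receive a logit.
    candidates = [true_class] + sampled
    logits = [sum(hi * wi for hi, wi in zip(h, W[c])) + b[c] for c in candidates]
    # TensorFlow subtracts log(expected count) from each logit; under a uniform
    # sampler that is the same constant for every class and leaves the softmax
    # unchanged, so this sketch omits it.
    # Cross-entropy with the true class at index 0.
    return logsumexp(logits) - logits[0]


def full_softmax_one(h, W, b, true_class):
    # Ordinary softmax cross-entropy over every class, for comparison.
    logits = [sum(hi * wi for hi, wi in zip(h, row)) + bias for row, bias in zip(W, b)]
    return logsumexp(logits) - logits[true_class]


if __name__ == "__main__":
    W = [[0.1 * (c + d) for d in range(3)] for c in range(5)]  # 5 classes, dim 3
    b = [0.0, 0.1, -0.2, 0.3, 0.0]
    h = [1.0, -0.5, 0.25]
    partial = sampled_softmax_one(h, W, b, true_class=2, num_sampled=2, rng=random.Random(0))
    exact = full_softmax_one(h, W, b, true_class=2)
    print(partial < exact)  # True
```

Because the normalizer sums over only a subset of classes, the sampled loss in this sketch is never larger than the full-softmax loss, and it equals it when every negative is sampled. This matches TensorFlow's note that the op is for training only and is generally an underestimate of the full softmax loss, so evaluation usually uses the full softmax.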