losses.py source code

python

Project: factorix  Author: gbouchar
import tensorflow as tf


def loss_func_softmax(pred, gold):
    """Softmax cross-entropy loss with integer labels as the second argument
    (instead of a one-hot encoded matrix).

    Args:
        pred: logits (log-odds); the last dimension is the number of labels
        gold: integer array with the same shape as pred, except that the last dimension is 1

    Returns:
        the softmax cross-entropy loss for each row of predictions

    """
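    # Flatten the predictions to an [n, voc_size] matrix and the labels to a vector of length n.
    pred = tf.reshape(pred, [-1, pred.get_shape()[-1].value])
    gold = tf.reshape(gold, [pred.get_shape()[0].value])
    n = pred.get_shape()[0].value
    voc_size = pred.get_shape()[1].value
    # Build (row, label) index pairs; tf.stack replaces tf.pack, which was removed in TensorFlow 1.0.
    rg = tf.range(0, n)
    inds = tf.transpose(tf.stack([rg, tf.cast(gold, 'int32')]))
    vals = tf.ones([n])
    # gold_mat = tf.SparseTensor( , [n, voc_size])
    # Scatter ones at those positions to obtain the one-hot label matrix.
    gold_mat = tf.sparse_to_dense(inds, [n, voc_size], vals)
    # TensorFlow 1.x requires named arguments for this call.
    return tf.nn.softmax_cross_entropy_with_logits(labels=gold_mat, logits=pred)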
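
To make the computation concrete, here is a minimal pure-Python sketch of the quantity the function above appears to compute: the per-row softmax cross-entropy against integer labels. The name `softmax_cross_entropy` and the sample inputs are hypothetical and not part of factorix. With a one-hot target, the cross-entropy reduces to the log-sum-exp of the row minus the logit of the gold label, so no one-hot matrix is needed.

```python
import math


def softmax_cross_entropy(logits, labels):
    """Per-row softmax cross-entropy for integer class labels (illustrative only)."""
    losses = []
    for row, label in zip(logits, labels):
        # Subtract the row maximum before exponentiating for numerical stability.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        # Against a one-hot target, the loss is -log(softmax(row)[label]).
        losses.append(log_sum_exp - row[label])
    return losses


# Hypothetical inputs: two rows of logits over two labels, with gold labels 0 and 1.
losses = softmax_cross_entropy([[0.0, 0.0], [1.0, 3.0]], [0, 1])
print(losses)
```

The first row has equal logits, so its loss is log(2); the second row's gold label has the larger logit, so its loss is smaller.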