losses.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:segmenter 作者: yanshao9798 项目源码 文件源码
def loss_wrapper(y, y_, loss_function, transitions=None, nums_tags=None, batch_size=None, weights=None, average_cross_steps=True):
    """Apply ``loss_function`` to each pair from ``y``/``y_`` and stack the results.

    ``loss_function`` is compared by identity against this module's loss
    functions. Each one receives different arguments:

    * ``crf_loss``: ``(sy, sy_, transition, num_tags, batch_size)``. Requires
      ``transitions`` and ``nums_tags`` to have the same length as ``y``, and
      ``batch_size`` to be set.
    * ``cross_entropy``: ``(sy, sy_, num_tags)``. Requires ``nums_tags`` to
      have the same length as ``y``.
    * ``sparse_cross_entropy``: ``(sy, sy_)``.
    * ``sparse_cross_entropy_with_weights``: ``(sy, sy_, weights=weights,
      average_cross_steps=average_cross_steps)``. Each result is flattened
      with ``tf.reshape(..., [-1])``.
    * any other callable: ``(sy, sy_)``. Each result is flattened with
      ``tf.reshape(..., [-1])``.

    Args:
        y: Sequence whose elements are passed as the first argument of
            ``loss_function``. NOTE(review): presumably one prediction tensor
            per output/task; confirm against callers.
        y_: Sequence parallel to ``y`` whose elements are passed as the second
            argument. Must have the same length as ``y``.
        loss_function: Loss callable; see the dispatch rules above.
        transitions: Per-entry transition parameters. Used only by ``crf_loss``.
        nums_tags: Per-entry tag counts. Used by ``crf_loss`` and
            ``cross_entropy``. Its length is also checked for
            ``sparse_cross_entropy_with_weights``.
        batch_size: Forwarded to ``crf_loss``. Required for that branch.
        weights: Forwarded to ``sparse_cross_entropy_with_weights``.
        average_cross_steps: Forwarded to ``sparse_cross_entropy_with_weights``.

    Returns:
        ``tf.stack`` of the per-entry losses. The ``crf_loss``,
        ``cross_entropy`` and ``sparse_cross_entropy`` branches do not
        flatten their results. Those per-entry losses must therefore already
        share a shape for ``tf.stack`` to succeed.

    Raises:
        AssertionError: If the parallel input sequences differ in length, or
            if a required argument for the chosen branch is missing.
    """
    assert len(y) == len(y_)
    total_loss = []
    if loss_function is crf_loss:
        assert len(y) == len(transitions) and len(transitions) == len(nums_tags) and batch_size is not None
        for sy, sy_, strans, snums_tags in zip(y, y_, transitions, nums_tags):
            total_loss.append(loss_function(sy, sy_, strans, snums_tags, batch_size))
    elif loss_function is cross_entropy:
        assert len(y) == len(nums_tags)
        for sy, sy_, snums_tags in zip(y, y_, nums_tags):
            total_loss.append(loss_function(sy, sy_, snums_tags))
    elif loss_function is sparse_cross_entropy:
        for sy, sy_ in zip(y, y_):
            total_loss.append(loss_function(sy, sy_))
    elif loss_function is sparse_cross_entropy_with_weights:
        # nums_tags is not passed to this loss. The length check is kept so
        # callers see the same validation as before.
        assert len(y) == len(nums_tags)
        # Bug fix: zip(y, y_) yields 2-tuples. The previous three-name unpack
        # ("sy, sy_, snums_tags") raised ValueError on any non-empty input.
        for sy, sy_ in zip(y, y_):
            total_loss.append(tf.reshape(loss_function(sy, sy_, weights=weights, average_cross_steps=average_cross_steps), [-1]))
    else:
        for sy, sy_ in zip(y, y_):
            total_loss.append(tf.reshape(loss_function(sy, sy_), [-1]))
    return tf.stack(total_loss)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号