loss.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:LIE 作者: EmbraceLife 项目源码 文件源码
def keras_wrap(model, target, output, loss):
    """ Convenience function for wrapping a Keras loss function.
    """
    # pylint: disable=import-error
    import keras.objectives as O
    import keras.backend as K
    # pylint: enable=import-error
    if isinstance(loss, str):
        loss = O.get(loss)
    shape = model.outputs[target].value._keras_shape # pylint: disable=protected-access
    ins = [
        (target, K.placeholder(
            ndim=len(shape),
            dtype=K.dtype(model.outputs[target].value),
            name=target
        ))
    ]
    out = loss(ins[0][1], output)
    return ins, out

###############################################################################
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号