def keras_wrap(model, target, output, loss):
    """ Wrap a Keras loss function around one named model output.

        A fresh backend placeholder is created to hold the ground-truth
        values for `target`; it has the same number of dimensions and the
        same dtype as the tensor found at `model.outputs[target].value`.
        The loss is then applied to that placeholder and `output`.

        # Arguments

        model: object exposing `outputs`, indexable by `target`, whose
            entries carry the underlying tensor in `.value`.
            NOTE(review): presumably the project's model wrapper -- confirm.
        target: key into `model.outputs`; also used as the placeholder name.
        output: the prediction passed as the second argument to the loss.
        loss: either a callable taking `(truth, prediction)`, or a string,
            in which case it is resolved through `keras.objectives.get`.

        # Return value

        A tuple `(ins, out)`: `ins` is a one-element list holding the pair
        `(target, placeholder)`, and `out` is whatever the loss returns.
    """
    # Keras is imported lazily, inside the function body.
    # pylint: disable=import-error
    import keras.objectives as O
    import keras.backend as K
    # pylint: enable=import-error

    # A loss given by name is looked up; anything else is used as-is.
    loss_fn = O.get(loss) if isinstance(loss, str) else loss

    target_tensor = model.outputs[target].value
    # Only the rank of the recorded shape is needed for the placeholder.
    rank = len(target_tensor._keras_shape)  # pylint: disable=protected-access

    truth = K.placeholder(
        ndim=rank,
        dtype=K.dtype(target_tensor),
        name=target
    )

    return [(target, truth)], loss_fn(truth, output)
###############################################################################
# (Web-page residue, not code) Comment list
# (Web-page residue, not code) Article table of contents