def py_func(func, inp, Tout, name=None, grad=None):
"""Redfine tf.py_func to include gradients"""
temp_name = next(tempfile._get_candidate_names())
_name = 'PyFuncGrad%s' %temp_name;
tf.RegisterGradient(_name)(grad)
g = tf.get_default_graph()
with g.gradient_override_map({"PyFunc": _name}):
return tf.py_func(func, inp, Tout, name=name)
machine_vision.py 文件源码
python
阅读 32
收藏 0
点赞 0
评论 0
评论列表
文章目录