def clip_norm(g, c, n):
    """Rescale gradient `g` by `c / n` when the norm `n` is at least `c`.

    Args:
        g: Gradient to clip.
            NOTE(review): presumably a backend tensor, or on TensorFlow an
            object exposing `dense_shape` (e.g. sparse gradient slices) --
            confirm against callers.
        c: Clipping threshold. Any value that is not strictly positive
            disables clipping and `g` is returned untouched.
        n: Norm compared against `c`.
            NOTE(review): presumably precomputed by the caller from the
            gradient(s) -- confirm.

    Returns:
        `g` itself when `c` is not positive. Otherwise a backend conditional
        that yields `g` scaled by `c / n` where `n >= c`, and `g` elsewhere.
    """
    # `not c > 0` (rather than `c <= 0`) keeps NaN thresholds on the no-op path.
    if not c > 0:
        return g

    if K.backend() != 'tensorflow':
        return K.switch(n >= c, g * c / n, g)

    # Imported only on the TensorFlow path, as in the original code.
    import tensorflow as tf
    import copy

    condition = n >= c
    # tf.scalar_mul is used instead of `*`; presumably so that non-dense
    # gradient objects can be scaled as well -- TODO confirm.
    scaled = tf.scalar_mul(c / n, g)

    # Snapshot the shape before the conditional so it can be re-applied to
    # the result afterwards.
    has_get_shape = hasattr(scaled, 'get_shape')
    has_dense_shape = not has_get_shape and hasattr(scaled, 'dense_shape')
    if has_get_shape:
        saved_shape = copy.copy(scaled.get_shape())
    elif has_dense_shape:
        saved_shape = copy.copy(scaled.dense_shape)

    if condition.dtype != tf.bool:
        condition = tf.cast(condition, 'bool')

    # Use the public tf.cond rather than reaching through
    # K.tensorflow_backend.control_flow_ops, which relies on a name that
    # another module merely happens to import.
    clipped = tf.cond(condition, lambda: scaled, lambda: g)

    if has_get_shape:
        clipped.set_shape(saved_shape)
    elif has_dense_shape:
        # No public setter is used here; the original code also wrote the
        # private attribute directly.
        clipped._dense_shape = saved_shape
    return clipped
评论列表
文章目录