keras_patches.py source code

python

Project: keras-image-captioning  Author: danieljl
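from keras import backend as K


def clip_norm(g, c, n):
    """Scale gradient g by c / n when its norm n is at least c; no-op if c <= 0."""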
    if c > 0:
        if K.backend() == 'tensorflow':
            import tensorflow as tf
            import copy
            condition = n >= c
            then_expression = tf.scalar_mul(c / n, g)
            else_expression = g

            if hasattr(then_expression, 'get_shape'):
                g_shape = copy.copy(then_expression.get_shape())
            elif hasattr(then_expression, 'dense_shape'):
                g_shape = copy.copy(then_expression.dense_shape)
            if condition.dtype != tf.bool:
                condition = tf.cast(condition, 'bool')
            g = K.tensorflow_backend.control_flow_ops.cond(
                condition, lambda: then_expression, lambda: else_expression)
            if hasattr(then_expression, 'get_shape'):
                g.set_shape(g_shape)
            elif hasattr(then_expression, 'dense_shape'):
                g._dense_shape = g_shape
        else:
            g = K.switch(n >= c, g * c / n, g)
    return g
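
The rule that clip_norm applies can be checked outside Keras. Below is a minimal NumPy sketch of the same arithmetic, not the project's code: the name clip_norm_np is hypothetical, and it assumes n is the L2 norm of g computed by the caller, as in the function above.

```python
import numpy as np


def clip_norm_np(g, c, n):
    """Hypothetical NumPy analogue of clip_norm.

    g: gradient array, c: clipping threshold, n: precomputed norm of g.
    """
    # Clipping is disabled when c <= 0, matching the `if c > 0` guard.
    if c > 0 and n >= c:
        # Rescale so that the result has norm c.
        return g * (c / n)
    # Norm is already below the threshold: return g unchanged.
    return g


g = np.array([3.0, 4.0])
n = np.sqrt(np.sum(np.square(g)))  # L2 norm of g
clipped = clip_norm_np(g, 1.0, n)  # norm exceeds 1.0, so g is rescaled
unchanged = clip_norm_np(g, 10.0, n)  # norm is below 10.0, so g is returned as-is
```

The TensorFlow branch above needs tf.cond and explicit shape restoration because n is a symbolic tensor there; with concrete NumPy values an ordinary if statement is enough.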