keras_extensions.py source code

python

Project: onto-lstm  Author: pdasigi
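from keras import backend as K


def switch(condition, then_tensor, else_tensor):
    """
    Element-wise switch that works with both the TensorFlow and Theano backends.
    Keras' implementation of switch for TensorFlow uses a control-flow op that accepts only scalar
    conditions, so this uses tf.select instead.
    """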
    if K.backend() == 'tensorflow':
        import tensorflow as tf
        condition_shape = condition.get_shape()
        input_shape = then_tensor.get_shape()
        if condition_shape[-1] != input_shape[-1] and condition_shape[-1] == 1:
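            # The last dim of the inputs is an embedding dim, which Keras does not mask, so the mask has a
            # trailing dim of 1. tf.select needs the condition to match the shape of then_tensor and
            # else_tensor, so tile the mask along the last dim by multiplying with a row of ones.
            condition = K.dot(tf.cast(condition, tf.float32), tf.ones((1, input_shape[-1])))
        # Note: tf.select was renamed to tf.where in TensorFlow 1.0.
        return tf.select(tf.cast(condition, dtype=tf.bool), then_tensor, else_tensor)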
    else:
        import theano.tensor as T
        return T.switch(condition, then_tensor, else_tensor)
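
The mask-tiling step in the TensorFlow branch can be checked without either backend. The following is a minimal sketch that uses NumPy as a stand-in: `switch_np` is a hypothetical helper that is not part of onto-lstm, and `np.dot` / `np.where` are assumed to play the roles of `K.dot` / `tf.select`.

```python
import numpy as np


def switch_np(condition, then_array, else_array):
    """NumPy stand-in for the switch above (hypothetical helper, for illustration only)."""
    condition = np.asarray(condition)
    then_array = np.asarray(then_array)
    else_array = np.asarray(else_array)
    if condition.shape[-1] != then_array.shape[-1] and condition.shape[-1] == 1:
        # Same trick as the original: (..., 1) dot (1, dim) gives (..., dim),
        # which copies each mask value across the embedding dimension.
        ones_row = np.ones((1, then_array.shape[-1]), dtype=np.float32)
        condition = np.dot(condition.astype(np.float32), ones_row)
    return np.where(condition.astype(bool), then_array, else_array)


# A mask of shape (batch=2, timesteps=3, 1) applied to inputs of shape (2, 3, 4).
mask = np.array([[[1], [0], [1]],
                 [[0], [0], [1]]])
then_array = np.ones((2, 3, 4))
else_array = np.zeros((2, 3, 4))
result = switch_np(mask, then_array, else_array)
print(result.shape)
```

Timesteps where the mask is 1 take their whole embedding vector from `then_array`, and the remaining timesteps take it from `else_array`.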