def switch(condition, then_tensor, else_tensor):
    """Element-wise select between ``then_tensor`` and ``else_tensor``.

    Per the original author's note, Keras' TensorFlow implementation of
    ``switch`` only accepts scalar conditions. This helper performs an
    element-wise select on the TensorFlow backend instead, and defers to
    ``theano.tensor.switch`` on any other backend.

    Args:
        condition: Tensor whose truthy entries choose ``then_tensor`` and
            whose falsy entries choose ``else_tensor``. On TensorFlow it is
            cast to ``tf.bool`` before selecting. If its last dimension is 1
            while the inputs' last dimension is larger, it is first repeated
            along that axis to match.
        then_tensor: Values taken where ``condition`` is true.
        else_tensor: Values taken where ``condition`` is false.

    Returns:
        A tensor combining ``then_tensor`` and ``else_tensor`` element-wise
        according to ``condition``.
    """
    if K.backend() == 'tensorflow':
        import tensorflow as tf
        condition_shape = condition.get_shape()
        input_shape = then_tensor.get_shape()
        if condition_shape[-1] != input_shape[-1] and condition_shape[-1] == 1:
            # This means the last dim is an embedding dim that Keras does not
            # mask, but TF's element-wise select needs the condition and both
            # branches to share a shape. Multiplying by a (1, dim) row of ones
            # repeats the size-1 condition axis across the embedding axis.
            condition = K.dot(tf.cast(condition, tf.float32), tf.ones((1, input_shape[-1])))
        # ``tf.select`` was removed in TensorFlow 1.0 in favour of the
        # three-argument ``tf.where``. Pre-1.0 releases only provide
        # ``tf.select`` for this purpose, so prefer it when it exists and
        # fall back to ``tf.where`` otherwise.
        select = getattr(tf, 'select', None) or tf.where
        return select(tf.cast(condition, dtype=tf.bool), then_tensor, else_tensor)

    import theano.tensor as T
    return T.switch(condition, then_tensor, else_tensor)
Comment list
Table of contents