def smart_cond(pred, fn1, fn2, name=None):
    """Return either fn1() or fn2() based on the boolean predicate/value `pred`.

    If `pred` is a Python bool or has a value that `constant_value` can
    determine statically, the branch is chosen at graph-construction time via
    `static_cond`. Otherwise a dynamic `cond` op is built from
    `control_flow_ops.cond`.

    Args:
        pred: A scalar determining whether to return the result of `fn1` or
            `fn2`.
        fn1: The callable to be performed if pred is true.
        fn2: The callable to be performed if pred is false.
        name: Optional name prefix, only used when a dynamic cond op is built.

    Returns:
        Tensors returned by the call to either `fn1` or `fn2`.
    """
    pred_value = constant_value(pred)
    if pred_value is not None:
        # The predicate is known statically: pick the branch now, no cond op.
        return static_cond(pred_value, fn1, fn2)

    # Dynamic predicate: build a cond op. `name` is passed by keyword because
    # in newer `cond` signatures the 4th positional parameter is `strict`, so
    # a positional `name` would be silently misinterpreted.
    return control_flow_ops.cond(pred, fn1, fn2, name=name)
评论列表
文章目录