distribution_util.py source code

python

Project: lsdc, author: febert
def pick_vector(cond,
                true_vector,
                false_vector,
                name="pick_vector"):
  """Picks possibly different length row `Tensor`s based on condition.

  Value `Tensor`s should have exactly one dimension.

  If `cond` is a python Boolean or `tf.constant` then either `true_vector` or
  `false_vector` is immediately returned. I.e., no graph nodes are created and
  no validation happens.

  Args:
    cond: `Tensor`. Must have `dtype=tf.bool` and be scalar.
    true_vector: `Tensor` of one dimension. Returned when cond is `True`.
    false_vector: `Tensor` of one dimension. Returned when cond is `False`.
    name: `String`. The name to give this op.

  Example:

  ```python
  pick_vector(tf.less(0, 5), tf.range(10, 12), tf.range(15, 18))
  # result is tensor: [10, 11].
  pick_vector(tf.less(5, 0), tf.range(10, 12), tf.range(15, 18))
  # result is tensor: [15, 16, 17].
  ```

  Returns:
    true_or_false_vector: `Tensor`.

  Raises:
    TypeError: if `cond.dtype != tf.bool`
    TypeError: if `cond` is not a constant and
      `true_vector.dtype != false_vector.dtype`
  """
  with ops.op_scope((cond, true_vector, false_vector), name):
    cond = ops.convert_to_tensor(cond, name="cond")
    if cond.dtype != dtypes.bool:
      raise TypeError("%s.dtype=%s which is not %s" %
                      (cond.name, cond.dtype, dtypes.bool))
    cond_value_static = tensor_util.constant_value(cond)
    if cond_value_static is not None:
      return true_vector if cond_value_static else false_vector
    true_vector = ops.convert_to_tensor(true_vector, name="true_vector")
    false_vector = ops.convert_to_tensor(false_vector, name="false_vector")
    if true_vector.dtype != false_vector.dtype:
      raise TypeError(
          "%s.dtype=%s does not match %s.dtype=%s"
          % (true_vector.name, true_vector.dtype,
             false_vector.name, false_vector.dtype))
    n = array_ops.shape(true_vector)[0]
    return array_ops.slice(array_ops.concat(0, (true_vector, false_vector)),
                           [math_ops.select(cond, 0, n)],
                           [math_ops.select(cond, n, -1)])
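
When `cond` is not statically known, the function above appears to pick between two vectors of different lengths by concatenating them and slicing out one half, with the slice start and size each chosen by `cond`. The following is a minimal pure-Python sketch of that idea, not the TensorFlow implementation. The names `select` and `pick_vector_sketch` are hypothetical, plain lists stand in for tensors, and a size of `-1` is assumed to mean "everything from the start index onward", as in the slice call above.

```python
def select(cond, if_true, if_false):
    """Scalar stand-in for a select op (hypothetical helper)."""
    return if_true if cond else if_false


def pick_vector_sketch(cond, true_vector, false_vector):
    """Illustrates concat-then-slice selection using plain lists."""
    n = len(true_vector)
    # Join both candidates into one sequence: [true..., false...].
    both = list(true_vector) + list(false_vector)
    # cond=True  -> start at 0, take n elements (the true part).
    # cond=False -> start at n, take the rest (the false part).
    begin = select(cond, 0, n)
    size = select(cond, n, -1)
    # Assumption: size == -1 means "through the end of the sequence".
    end = len(both) if size == -1 else begin + size
    return both[begin:end]


print(pick_vector_sketch(True, range(10, 12), range(15, 18)))   # [10, 11]
print(pick_vector_sketch(False, range(10, 12), range(15, 18)))  # [15, 16, 17]
```

Because only the start index and size depend on `cond`, the same sequence of operations runs for either outcome, which is presumably why the original can handle vectors of different lengths inside a static graph.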
