python类cond()的实例源码

utils.py 文件源码 项目:opinatt 作者: epochx 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def _step(time, sequence_length, min_sequence_length, max_sequence_length, zero_logit, generate_logit):
  """Build the logit for one time step, substituting `zero_logit` where needed.

  Args:
    time: Current step index; compared against the three length arguments.
    sequence_length: Length(s) compared element-wise with `time`; wherever
      `time >= sequence_length` the result is taken from `zero_logit`.
    min_sequence_length: While `time` is below this value the generated logit
      is returned without any masking.
    max_sequence_length: Once `time` reaches this value `generate_logit` is
      not selected and `zero_logit` is used instead.
    zero_logit: Tensor used in place of the generated logit.
    generate_logit: Zero-argument callable that produces the logit.

  Returns:
    The logit tensor for this step, with its static shape taken from
    `zero_logit`.
  """
  # Skip generate_logit once time has reached max_sequence_length.
  generated = control_flow_ops.cond(
    time < max_sequence_length, generate_logit, lambda: zero_logit)

  def mask_finished():
    # Element-wise choice: zero_logit where time >= sequence_length,
    # the generated logit everywhere else.
    # NOTE(review): `select` was renamed to `where` in TensorFlow 1.0 --
    # confirm which TF version this file targets.
    finished = (time >= sequence_length)
    return math_ops.select(finished, zero_logit, generated)

  # Below min_sequence_length nothing needs masking, so pass the value through.
  logit = control_flow_ops.cond(
    time < min_sequence_length, lambda: generated, mask_finished)
  # Restore the static shape from zero_logit. The previous code called
  # set_shape with the tensor's own shape, which has no effect.
  logit.set_shape(zero_logit.get_shape())
  return logit
util.py 文件源码 项目:tefla 作者: litan 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def static_cond(pred, fn1, fn2):
    """Call and return `fn1()` when `pred` is truthy, otherwise `fn2()`.

    Mirrors the signature of `control_flow_ops.cond()`, but the decision is
    made immediately in Python from the truth value of `pred`.

    Args:
        pred: Value whose truthiness selects the branch.
        fn1: Callable invoked when `pred` is true.
        fn2: Callable invoked when `pred` is false.

    Returns:
        Whatever the selected callable returns.

    Raises:
        TypeError: if `fn1` or `fn2` is not callable.
    """
    # Both branches are validated up front, even though only one is called.
    if not callable(fn1):
        raise TypeError('fn1 must be callable.')
    if not callable(fn2):
        raise TypeError('fn2 must be callable.')
    return fn1() if pred else fn2()
util.py 文件源码 项目:tefla 作者: litan 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def smart_cond(pred, fn1, fn2, name=None):
    """Pick between `fn1()` and `fn2()`, statically when possible.

    `constant_value(pred)` is consulted first. When it yields something other
    than None the branch is chosen immediately through `static_cond`;
    otherwise the choice is deferred to `control_flow_ops.cond`.

    Args:
        pred: A scalar determining whether to return the result of `fn1` or `fn2`.
        fn1: The callable to be performed if pred is true.
        fn2: The callable to be performed if pred is false.
        name: Optional name prefix, used only on the dynamic path.
    Returns:
        Tensors returned by the call to either `fn1` or `fn2`.
    """
    static_pred = constant_value(pred)
    if static_pred is None:
        # No constant value available: build a dynamic conditional.
        return control_flow_ops.cond(pred, fn1, fn2, name)
    # A constant value is known: resolve the branch right away.
    return static_cond(static_pred, fn1, fn2)
seq_labeling.py 文件源码 项目:joint-slu-lm 作者: HadoopIt 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def _step(time, sequence_length, min_sequence_length, max_sequence_length, zero_logit, generate_logit):
  """Compute the logit for a single time step, with `zero_logit` as fallback.

  Args:
    time: Current step index; compared against the three length arguments.
    sequence_length: Length(s) compared element-wise with `time`; wherever
      `time >= sequence_length` the result is taken from `zero_logit`.
    min_sequence_length: While `time` is below this value the generated logit
      is returned without any masking.
    max_sequence_length: Once `time` reaches this value `generate_logit` is
      not selected and `zero_logit` is used instead.
    zero_logit: Tensor used in place of the generated logit.
    generate_logit: Zero-argument callable that produces the logit.

  Returns:
    The logit tensor for this step, with its static shape taken from
    `zero_logit`.
  """
  # Stage 1: use generate_logit only while time < max_sequence_length.
  step_logit = control_flow_ops.cond(
      time < max_sequence_length, generate_logit, lambda: zero_logit)

  def copy_through():
    # Element-wise choice: zero_logit where time >= sequence_length,
    # the generated logit everywhere else.
    # NOTE(review): `select` was renamed to `where` in TensorFlow 1.0 --
    # confirm which TF version this file targets.
    copy_cond = (time >= sequence_length)
    return math_ops.select(copy_cond, zero_logit, step_logit)

  # Stage 2: masking is only needed once time >= min_sequence_length.
  logit = control_flow_ops.cond(
      time < min_sequence_length, lambda: step_logit, copy_through)
  # Restore the static shape from zero_logit. The previous code called
  # set_shape with the tensor's own shape, which has no effect.
  logit.set_shape(zero_logit.get_shape())
  return logit
tensor_forest.py 文件源码 项目:deep-learning 作者: lbkchen 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def average_impurity(self):
    """Builds graph ops that evaluate the average leaf impurity of a tree.

    In regression mode this is the leaf variance; in classification mode it
    is the gini impurity.

    Returns:
      The output of the final conditional op: the weighted gini value when at
      least one leaf exists, otherwise a tensor of the same shape filled with
      a large constant.
    """
    # First column of the tree variable, flattened to one dimension.
    first_column = array_ops.slice(self.variables.tree, [0, 0], [-1, 1])
    children = array_ops.squeeze(first_column, squeeze_dims=[1])
    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
    leaf_positions = array_ops.squeeze(array_ops.where(is_leaf),
                                       squeeze_dims=[1])
    leaves = math_ops.to_int32(leaf_positions)
    counts = array_ops.gather(self.variables.node_sums, leaves)
    gini = self._weighted_gini(counts)

    # Early on there are often no leaves yet. Because this value can serve as
    # a loss, the no-leaf case yields a big number so the loss can only fall.
    has_leaves = math_ops.greater(array_ops.shape(leaves)[0], 0)
    return control_flow_ops.cond(
        has_leaves,
        lambda: gini,
        lambda: array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.)
tf_image.py 文件源码 项目:SSD_tensorflow_VOC 作者: LevinJ 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def _assert(cond, ex_type, msg):
    """Assert `cond`, whether it is a tensor or a plain Python value.

    For a tensor, a TensorFlow assert op is created and returned inside a
    one-element list. For anything else the check happens immediately: a
    falsy `cond` raises `ex_type(msg)`, a truthy one yields an empty list.
    Args:
      cond: Something evaluates to a boolean value. May be a tensor.
      ex_type: The exception class to use.
      msg: The error message.
    Returns:
      A list, containing at most one assert op.
    """
    if _is_tensor(cond):
        return [control_flow_ops.Assert(cond, [msg])]
    # Plain Python value: behave like an ordinary assert statement.
    if cond:
        return []
    raise ex_type(msg)
util.py 文件源码 项目:tefla 作者: openAGI 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def static_cond(pred, fn1, fn2):
    """Evaluate one of two callables according to the truth value of `pred`.

    Takes the same arguments as `control_flow_ops.cond()`, except that `pred`
    is tested directly in Python rather than inside the graph.

    Args:
        pred: Value whose truthiness selects the branch.
        fn1: Callable invoked when `pred` is true.
        fn2: Callable invoked when `pred` is false.

    Returns:
        Whatever the selected callable returns.

    Raises:
        TypeError: if `fn1` or `fn2` is not callable.
    """
    if not callable(fn1):
        raise TypeError('fn1 must be callable.')
    if not callable(fn2):
        raise TypeError('fn2 must be callable.')
    # Pick the branch first, then invoke it exactly once.
    branch = fn1 if pred else fn2
    return branch()
util.py 文件源码 项目:tefla 作者: openAGI 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def smart_cond(pred, fn1, fn2, name=None):
    """Return `fn1()` or `fn2()`, resolving `pred` statically when it is known.

    If `constant_value(pred)` is not None, that value is handed to
    `static_cond`; otherwise `control_flow_ops.cond` builds the conditional.

    Args:
        pred: A scalar determining whether to return the result of `fn1` or `fn2`.
        fn1: The callable to be performed if pred is true.
        fn2: The callable to be performed if pred is false.
        name: Optional name prefix, used only on the dynamic path.
    Returns:
        Tensors returned by the call to either `fn1` or `fn2`.
    """
    known_value = constant_value(pred)
    pred_is_static = known_value is not None
    if pred_is_static:
        return static_cond(known_value, fn1, fn2)
    # pred has no constant value, so the branch is selected at run time.
    return control_flow_ops.cond(pred, fn1, fn2, name)
utils.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def static_cond(pred, fn1, fn2):
  """Run `fn1` if `pred` is truthy, else `fn2`, and return the result.

  Shares its signature with `control_flow_ops.cond()`, but the predicate is
  evaluated as an ordinary Python truth value.

  Args:
    pred: Value whose truthiness selects the branch.
    fn1: Callable invoked when `pred` is true.
    fn2: Callable invoked when `pred` is false.

  Returns:
    Whatever the selected callable returns.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable.
  """
  # Validate both branches before deciding which one to run.
  if not callable(fn1):
    raise TypeError('fn1 must be callable.')
  if not callable(fn2):
    raise TypeError('fn2 must be callable.')
  return fn1() if pred else fn2()
utils.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def smart_cond(pred, fn1, fn2, name=None):
  """Choose between `fn1()` and `fn2()`, statically if `pred` is constant.

  `constant_value(pred)` is checked first; a non-None result is passed to
  `static_cond`, while None falls back to `control_flow_ops.cond`.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix, used only on the dynamic path.
  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.
  """
  static_pred = constant_value(pred)
  if static_pred is None:
    # Nothing known ahead of time: emit a dynamic conditional.
    return control_flow_ops.cond(pred, fn1, fn2, name)
  return static_cond(static_pred, fn1, fn2)
metric_ops.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 when the denominator equals 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  # The results are discarded; these calls presumably serve as static rank
  # checks on the two inputs -- TODO confirm.
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)

  zero = array_ops.constant(0.0, dtype=dtypes.float64)
  denominator_is_zero = math_ops.equal(zero, denominator)

  def _zero_result():
    return array_ops.constant(0.0, dtype=dtypes.float64)

  def _quotient():
    return math_ops.div(numerator, denominator)

  return control_flow_ops.cond(
      denominator_is_zero, _zero_result, _quotient, name=name)
utils.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def static_cond(pred, fn1, fn2):
  """Select and call `fn1` or `fn2` based on the truthiness of `pred`.

  Has the same signature as `control_flow_ops.cond()`, but the branch is
  chosen in Python at call time.

  Args:
    pred: Value whose truthiness selects the branch.
    fn1: Callable invoked when `pred` is true.
    fn2: Callable invoked when `pred` is false.

  Returns:
    Whatever the selected callable returns.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable.
  """
  if not callable(fn1):
    raise TypeError('fn1 must be callable.')
  if not callable(fn2):
    raise TypeError('fn2 must be callable.')
  # Only the selected branch is ever invoked.
  chosen = fn1 if pred else fn2
  return chosen()
utils.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def smart_cond(pred, fn1, fn2, name=None):
  """Return the result of `fn1` or `fn2`, deciding early when possible.

  When `constant_value(pred)` gives a non-None value the decision is made by
  `static_cond`; otherwise it is left to `control_flow_ops.cond`.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix, used only on the dynamic path.
  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.
  """
  resolved = constant_value(pred)
  has_static_value = resolved is not None
  if has_static_value:
    return static_cond(resolved, fn1, fn2)
  # Fall back to a conditional evaluated at run time.
  return control_flow_ops.cond(pred, fn1, fn2, name)
tensor_forest.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def _get_loss(self, features, labels, data_spec=None):
    """Constructs, caches, and returns the inference-based loss."""
    if self._loss is not None:
      return self._loss

    def _average_loss():
      probs = self.inference_graph(features, data_spec=data_spec)
      return math_ops.reduce_sum(self.loss_fn(
          probs, labels)) / math_ops.to_float(
              array_ops.shape(features)[0])

    self._loss = control_flow_ops.cond(
        self.average_size() > 0, _average_loss,
        lambda: constant_op.constant(sys.maxsize, dtype=dtypes.float32))

    return self._loss
tensor_forest.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def average_impurity(self):
    """Constructs graph ops for the average leaf impurity of a tree.

    This is the leaf variance in regression mode and the gini impurity in
    classification mode.

    Returns:
      The output of the final conditional op: the weighted gini value when at
      least one leaf exists, otherwise a same-shaped tensor holding a large
      constant.
    """
    tree_first_column = array_ops.slice(
        self.variables.tree, [0, 0], [-1, 1])
    children = array_ops.squeeze(tree_first_column, squeeze_dims=[1])
    leaf_mask = math_ops.equal(constants.LEAF_NODE, children)
    leaf_rows = array_ops.squeeze(array_ops.where(leaf_mask),
                                  squeeze_dims=[1])
    leaves = math_ops.to_int32(leaf_rows)
    counts = array_ops.gather(self.variables.node_sums, leaves)
    gini = self._weighted_gini(counts)

    def _impurity():
      return gini

    def _large_value():
      # With no leaves yet there is nothing to measure; since the result may
      # be used as a loss, return a big number so the loss always decreases.
      return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.

    leaf_count = array_ops.shape(leaves)[0]
    return control_flow_ops.cond(
        math_ops.greater(leaf_count, 0), _impurity, _large_value)
datasets.py 文件源码 项目:self-supervision 作者: gustavla 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _assert(cond, ex_type, msg):
  """Polymorphic assert that accepts tensors and ordinary boolean values.

  A tensor `cond` produces a single TensorFlow assert op wrapped in a list.
  A non-tensor `cond` is checked on the spot: falsy raises `ex_type(msg)`,
  truthy returns an empty list.

  Args:
    cond: Something evaluates to a boolean value. May be a tensor.
    ex_type: The exception class to use.
    msg: The error message.

  Returns:
    A list, containing at most one assert op.
  """
  if is_tensor(cond):
    return [logging_ops.Assert(cond, [msg])]
  # Non-tensor path: act like a regular assert statement.
  if cond:
    return []
  raise ex_type(msg)
preprocess.py 文件源码 项目:reslearn 作者: mackcmillion 项目源码 文件源码 阅读 15 收藏 0 点赞 0 评论 0
def _resize_aux(image, new_shorter_edge_tensor):
    """Resize `image` so that its shorter edge becomes `new_shorter_edge_tensor`.

    The longer edge is obtained from `_compute_longer_edge`. When height and
    width are equal, height is treated as the shorter edge.
    """
    image_shape = tf.shape(image)
    height, width = image_shape[0], image_shape[1]

    def _height_is_shorter():
        longer = _compute_longer_edge(height, width, new_shorter_edge_tensor)
        return (new_shorter_edge_tensor, longer)

    def _width_is_shorter():
        longer = _compute_longer_edge(width, height, new_shorter_edge_tensor)
        return (longer, new_shorter_edge_tensor)

    target_size = cf.cond(
            tf.less_equal(height, width), _height_is_shorter, _width_is_shorter)

    # Workaround for tf.image.resize_images() not working: add a leading
    # dimension, resize with resize_bilinear, then remove the dimension again.
    batched = tf.expand_dims(image, 0)
    resized = tf.image.resize_bilinear(batched, tf.pack(target_size))
    return tf.squeeze(resized, [0])
tf_image.py 文件源码 项目:MobileNet 作者: Zehaos 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def _assert(cond, ex_type, msg):
    """Check `cond` either as a graph assertion or as an immediate assertion.

    When `cond` is a tensor, the returned list holds one TensorFlow assert
    op. Otherwise `cond` is tested immediately: `ex_type(msg)` is raised if
    it is falsy, and an empty list is returned if it is truthy.
    Args:
      cond: Something evaluates to a boolean value. May be a tensor.
      ex_type: The exception class to use.
      msg: The error message.
    Returns:
      A list, containing at most one assert op.
    """
    if not _is_tensor(cond):
        # Ordinary value: fail right now or report that no op is needed.
        if cond:
            return []
        raise ex_type(msg)
    return [control_flow_ops.Assert(cond, [msg])]
tf_image.py 文件源码 项目:MobileNet 作者: Zehaos 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def random_flip_left_right(image, bboxes, seed=None):
    """Randomly mirror an image left-right together with its bounding boxes.

    One uniform random draw decides both flips, so the image and the boxes
    are always flipped together or left untouched together.

    Returns:
      A pair of the result of `fix_image_flip_shape` on the image and the
      (possibly flipped) bounding boxes.
    """
    def _mirror_boxes(boxes):
        # Columns 0 and 2 are kept; columns 1 and 3 swap places and are
        # reflected as `1 - value`.
        first = boxes[:, 0]
        second = 1 - boxes[:, 3]
        third = boxes[:, 2]
        fourth = 1 - boxes[:, 1]
        return tf.stack([first, second, third, fourth], axis=-1)

    with tf.name_scope('random_flip_left_right'):
        image = ops.convert_to_tensor(image, name='image')
        _Check3DImage(image, require_static=False)
        draw = random_ops.random_uniform([], 0, 1.0, seed=seed)
        mirror_cond = math_ops.less(draw, .5)
        # The same predicate drives both conditionals.
        flipped_image = control_flow_ops.cond(
            mirror_cond,
            lambda: array_ops.reverse_v2(image, [1]),
            lambda: image)
        flipped_bboxes = control_flow_ops.cond(
            mirror_cond,
            lambda: _mirror_boxes(bboxes),
            lambda: bboxes)
        return fix_image_flip_shape(image, flipped_image), flipped_bboxes
seq2seq.py 文件源码 项目:seqGan_chatbot 作者: zpppy 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def _argmax_or_mcsearch(embedding, output_projection=None, update_embedding=True, mc_search=False):
    """Create a loop function that embeds the symbol chosen from the previous output.

    Args:
        embedding: Embedding parameters passed to `embedding_lookup`.
        output_projection: Optional indexable whose items 0 and 1 are passed
            to `xw_plus_b` to project the previous output first.
        update_embedding: When false, gradients are stopped at the looked-up
            embedding.
        mc_search: Python bool or tensor predicate. True selects sampling
            with `tf.multinomial`; False selects argmax.

    Returns:
        A function `loop_function(prev, _)` returning the embedded symbol.
    """
    def loop_function(prev, _):
        if output_projection is not None:
            prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])

        if not isinstance(mc_search, bool):
            # Predicate is not a Python bool: choose the strategy with tf.cond.
            prev_symbol = tf.cond(mc_search, lambda: tf.reshape(tf.multinomial(prev, 1), [-1]), lambda: tf.argmax(prev, 1))
        elif mc_search:
            # Draw one sample per row from `prev`, then flatten with [-1].
            prev_symbol = tf.reshape(tf.multinomial(prev, 1), [-1])
        else:
            prev_symbol = math_ops.argmax(prev, 1)

        emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
        if update_embedding:
            return emb_prev
        # Keep gradients from flowing back through the looked-up embedding.
        return array_ops.stop_gradient(emb_prev)
    return loop_function
model.py 文件源码 项目:tensorflow-cnn-finetune 作者: dgurkaynak 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def bn(x, is_training):
    """Apply batch normalization to `x` over all but its last dimension.

    Args:
        x: Input tensor; parameters are sized by its last dimension.
        is_training: Predicate passed to `control_flow_ops.cond`. True uses
            the moments of `x`; False uses the stored moving averages.

    Returns:
        The result of `tf.nn.batch_normalization` on `x`.
    """
    x_shape = x.get_shape()
    params_shape = x_shape[-1:]
    reduction_axes = list(range(len(x_shape) - 1))

    beta = _get_variable('beta', params_shape, initializer=tf.zeros_initializer())
    gamma = _get_variable('gamma', params_shape, initializer=tf.ones_initializer())
    moving_mean = _get_variable(
        'moving_mean', params_shape, initializer=tf.zeros_initializer(), trainable=False)
    moving_variance = _get_variable(
        'moving_variance', params_shape, initializer=tf.ones_initializer(), trainable=False)

    # These ops will only be performed when training.
    batch_mean, batch_variance = tf.nn.moments(x, reduction_axes)
    update_ops = (
        moving_averages.assign_moving_average(moving_mean, batch_mean, BN_DECAY),
        moving_averages.assign_moving_average(moving_variance, batch_variance, BN_DECAY),
    )
    for update_op in update_ops:
        tf.add_to_collection(UPDATE_OPS_COLLECTION, update_op)

    def _batch_statistics():
        return (batch_mean, batch_variance)

    def _moving_statistics():
        return (moving_mean, moving_variance)

    mean, variance = control_flow_ops.cond(
        is_training, _batch_statistics, _moving_statistics)

    return tf.nn.batch_normalization(x, mean, variance, beta, gamma, BN_EPSILON)
seq_labeling.py 文件源码 项目:rnn-nlu 作者: HadoopIt 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def _step(time, sequence_length, min_sequence_length, 
          max_sequence_length, zero_logit, generate_logit):
  """Produce the logit for one time step, using `zero_logit` where needed.

  Args:
    time: Current step index; compared against the three length arguments.
    sequence_length: Length(s) compared element-wise with `time`; wherever
      `time >= sequence_length` the result is taken from `zero_logit`.
    min_sequence_length: While `time` is below this value the generated logit
      is returned without any masking.
    max_sequence_length: Once `time` reaches this value `generate_logit` is
      not selected and `zero_logit` is used instead.
    zero_logit: Tensor used in place of the generated logit.
    generate_logit: Zero-argument callable that produces the logit.

  Returns:
    The logit tensor for this step, with its static shape taken from
    `zero_logit`.
  """
  # Phase 1: generate a logit only while time < max_sequence_length.
  generated = control_flow_ops.cond(
      time < max_sequence_length, generate_logit, lambda: zero_logit)

  def mask_finished():
    # Element-wise choice: zero_logit where time >= sequence_length,
    # the generated logit everywhere else.
    finished = (time >= sequence_length)
    return tf.where(finished, zero_logit, generated)

  # Phase 2: masking only matters once time >= min_sequence_length.
  logit = control_flow_ops.cond(
      time < min_sequence_length, lambda: generated, mask_finished)
  logit.set_shape(zero_logit.get_shape())
  return logit
utils.py 文件源码 项目:emoatt 作者: epochx 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def _step(time, sequence_length, min_sequence_length, max_sequence_length, zero_logit, generate_logit):
  """Return the logit for the current time step, falling back to `zero_logit`.

  Args:
    time: Current step index; compared against the three length arguments.
    sequence_length: Length(s) compared element-wise with `time`; wherever
      `time >= sequence_length` the result is taken from `zero_logit`.
    min_sequence_length: While `time` is below this value the generated logit
      is returned without any masking.
    max_sequence_length: Once `time` reaches this value `generate_logit` is
      not selected and `zero_logit` is used instead.
    zero_logit: Tensor used in place of the generated logit.
    generate_logit: Zero-argument callable that produces the logit.

  Returns:
    The logit tensor for this step, with its static shape taken from
    `zero_logit`.
  """
  # generate_logit is bypassed once time has reached max_sequence_length.
  raw_logit = control_flow_ops.cond(
    time < max_sequence_length, generate_logit, lambda: zero_logit)

  def copy_through():
    # Element-wise choice: zero_logit where time >= sequence_length,
    # the generated logit everywhere else.
    # NOTE(review): `select` was renamed to `where` in TensorFlow 1.0 --
    # confirm which TF version this file targets.
    copy_cond = (time >= sequence_length)
    return math_ops.select(copy_cond, zero_logit, raw_logit)

  # Before min_sequence_length the generated value is passed through as is.
  logit = control_flow_ops.cond(
    time < min_sequence_length, lambda: raw_logit, copy_through)
  # Restore the static shape from zero_logit. The previous code called
  # set_shape with the tensor's own shape, which has no effect.
  logit.set_shape(zero_logit.get_shape())
  return logit
tf_image.py 文件源码 项目:seglink 作者: dengdan 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def _assert(cond, ex_type, msg):
    """Assertion helper that works for both tensors and Python booleans.

    Tensor input: returns a list with a single TensorFlow assert op.
    Other input: raises `ex_type(msg)` if `cond` is falsy, otherwise returns
    an empty list.
    Args:
      cond: Something evaluates to a boolean value. May be a tensor.
      ex_type: The exception class to use.
      msg: The error message.
    Returns:
      A list, containing at most one assert op.
    """
    if _is_tensor(cond):
        assert_op = control_flow_ops.Assert(cond, [msg])
        return [assert_op]
    # Non-tensor condition: evaluate it immediately.
    if cond:
        return []
    raise ex_type(msg)
tf_image.py 文件源码 项目:seglink 作者: dengdan 项目源码 文件源码 阅读 16 收藏 0 点赞 0 评论 0
def random_flip_left_right(image, bboxes, seed=None):
    """Flip an image and its bounding boxes left-right with a random draw.

    The image and the boxes share one predicate (`uniform draw < .5`), so
    they are flipped together or not at all.

    Returns:
      A pair of the result of `fix_image_flip_shape` on the image and the
      (possibly flipped) bounding boxes.
    """
    def _flipped(boxes):
        # Columns 0 and 2 pass through; columns 1 and 3 trade places and are
        # reflected as `1 - value`.
        columns = [boxes[:, 0], 1 - boxes[:, 3],
                   boxes[:, 2], 1 - boxes[:, 1]]
        return tf.stack(columns, axis=-1)

    def _unchanged_boxes():
        return bboxes

    with tf.name_scope('random_flip_left_right'):
        image = ops.convert_to_tensor(image, name='image')
        _Check3DImage(image, require_static=False)
        sample = random_ops.random_uniform([], 0, 1.0, seed=seed)
        mirror_cond = math_ops.less(sample, .5)
        # Image first, then boxes, both driven by the same predicate.
        result = control_flow_ops.cond(
            mirror_cond,
            lambda: array_ops.reverse_v2(image, [1]),
            lambda: image)
        new_bboxes = control_flow_ops.cond(
            mirror_cond,
            lambda: _flipped(bboxes),
            _unchanged_boxes)
        return fix_image_flip_shape(image, result), new_bboxes
tf_image.py 文件源码 项目:DAVIS-2016-Chanllege-Solution 作者: tangyuhao 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def _assert(cond, ex_type, msg):
    """Assert a condition that may be a tensor or a regular Python value.

    A tensor condition becomes a TensorFlow assert op, returned in a
    one-element list. Any other condition is checked immediately and either
    raises `ex_type(msg)` or returns an empty list.
    Args:
      cond: Something evaluates to a boolean value. May be a tensor.
      ex_type: The exception class to use.
      msg: The error message.
    Returns:
      A list, containing at most one assert op.
    """
    if not _is_tensor(cond):
        # Behaves like an ordinary assert, except that a list is returned.
        if not cond:
            raise ex_type(msg)
        return []
    return [control_flow_ops.Assert(cond, [msg])]
tf_image.py 文件源码 项目:DAVIS-2016-Chanllege-Solution 作者: tangyuhao 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def random_flip_left_right(image, bboxes, seed=None):
    """Apply a random left-right flip to an image and its bounding boxes.

    A single uniform sample compared with .5 controls both the image flip
    and the box flip.

    Returns:
      A pair of the result of `fix_image_flip_shape` on the image and the
      (possibly flipped) bounding boxes.
    """
    def _flip_boxes(boxes):
        # Columns 0 and 2 are unchanged; columns 1 and 3 are exchanged and
        # each reflected as `1 - value`.
        col_a = boxes[:, 0]
        col_b = 1 - boxes[:, 3]
        col_c = boxes[:, 2]
        col_d = 1 - boxes[:, 1]
        return tf.stack([col_a, col_b, col_c, col_d], axis=-1)

    def _flip_image():
        return array_ops.reverse_v2(image, [1])

    with tf.name_scope('random_flip_left_right'):
        image = ops.convert_to_tensor(image, name='image')
        _Check3DImage(image, require_static=False)
        uniform_sample = random_ops.random_uniform([], 0, 1.0, seed=seed)
        mirror_cond = math_ops.less(uniform_sample, .5)
        mirrored = control_flow_ops.cond(mirror_cond,
                                         _flip_image,
                                         lambda: image)
        out_bboxes = control_flow_ops.cond(mirror_cond,
                                           lambda: _flip_boxes(bboxes),
                                           lambda: bboxes)
        return fix_image_flip_shape(image, mirrored), out_bboxes
stochastic_graph_test.py 文件源码 项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def testTraversesControlInputs(self):
    """Checks the dependency map for an output built through a conditional.

    Three stochastic tensors are involved: one feeding the logits of the
    second, the second used as the predicate, and the third used inside a
    branch. Each is expected to map to the final output.
    """
    logits_source = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
    logits = logits_source.value() * 3.
    predicate_source = st.StochasticTensor(distributions.Bernoulli(logits=logits))
    branch_source = st.StochasticTensor(distributions.Normal(loc=0., scale=1.))
    x = branch_source.value()
    y = array_ops.ones((2, 2)) * 4.
    z = array_ops.ones((2, 2)) * 3.
    pred = math_ops.cast(predicate_source, dtypes.bool)
    out = control_flow_ops.cond(
        pred, lambda: math_ops.add(x, y), lambda: math_ops.square(z))
    out += 5.
    dep_map = sg._stochastic_dependencies_map([out])
    expected = set([out])
    for source in (logits_source, predicate_source, branch_source):
        self.assertEqual(dep_map[source], expected)
utils.py 文件源码 项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def static_cond(pred, fn1, fn2):
  """Return `fn1()` for a truthy `pred` and `fn2()` for a falsy one.

  The signature matches `control_flow_ops.cond()`, but `pred` is treated as
  a plain Python truth value and the branch runs immediately.

  Args:
    pred: Value whose truthiness selects the branch.
    fn1: Callable invoked when `pred` is true.
    fn2: Callable invoked when `pred` is false.

  Returns:
    Whatever the selected callable returns.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable.
  """
  # Reject non-callable branches before looking at pred.
  if not callable(fn1):
    raise TypeError('fn1 must be callable.')
  if not callable(fn2):
    raise TypeError('fn2 must be callable.')
  if not pred:
    return fn2()
  return fn1()
utils.py 文件源码 项目:DeepLearning_VirtualReality_BigData_Project 作者: rashmitripathi 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def smart_cond(pred, fn1, fn2, name=None):
  """Return `fn1()` or `fn2()`, short-circuiting when `pred` is constant.

  A non-None result from `constant_value(pred)` is resolved through
  `static_cond`; a None result is handled by `control_flow_ops.cond`.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix, used only on the dynamic path.
  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.
  """
  const_pred = constant_value(pred)
  if const_pred is None:
    # The predicate is only known at run time.
    return control_flow_ops.cond(pred, fn1, fn2, name)
  # The predicate is already known, so pick the branch now.
  return static_cond(const_pred, fn1, fn2)


问题


面经


文章

微信
公众号

扫码关注公众号