python类select()的实例源码

utils.py 文件源码 项目:opinatt 作者: epochx 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _step(time, sequence_length, min_sequence_length, max_sequence_length, zero_logit, generate_logit):
  """Builds the logit for one time step, zeroing examples that have ended.

  Args:
    time: Current step; compared against the three length arguments.
    sequence_length: Lengths compared element-wise with `time`; entries with
      `time >= sequence_length` receive `zero_logit` (presumably one length
      per batch example -- confirm against the caller).
    min_sequence_length: When `time` is below this value the computed logit
      is returned without any masking.
    max_sequence_length: When `time` is at or beyond this value
      `generate_logit` is not used and `zero_logit` is returned instead.
    zero_logit: Value used in place of a computed logit.
    generate_logit: Zero-argument callable producing the logit for this step.

  Returns:
    The logit for this step. Its static shape is taken from `zero_logit`
    when `zero_logit` exposes one.
  """
  # Step 1: decide whether the logit has to be computed at all.
  step_logit = control_flow_ops.cond(
      time < max_sequence_length, generate_logit, lambda: zero_logit)

  def copy_through():
    # Pick zero_logit for entries whose sequence has already ended and the
    # freshly computed logit for all the others.
    finished = (time >= sequence_length)
    return math_ops.select(finished, zero_logit, step_logit)

  # Step 2: masking is only needed once `time` reaches min_sequence_length.
  masked_logit = control_flow_ops.cond(
      time < min_sequence_length, lambda: step_logit, copy_through)

  # The previous code called set_shape() with the tensor's own shape, which
  # cannot add any information. Take the static shape from zero_logit
  # instead; the select above already treats both as interchangeable.
  if hasattr(zero_logit, "get_shape"):
    masked_logit.set_shape(zero_logit.get_shape())
  return masked_logit
seq_labeling.py 文件源码 项目:joint-slu-lm 作者: HadoopIt 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _step(time, sequence_length, min_sequence_length, max_sequence_length, zero_logit, generate_logit):
  """Builds the logit for one time step, zeroing examples that have ended.

  Args:
    time: Current step; compared against the three length arguments.
    sequence_length: Lengths compared element-wise with `time`; entries with
      `time >= sequence_length` receive `zero_logit` (presumably one length
      per batch example -- confirm against the caller).
    min_sequence_length: When `time` is below this value the computed logit
      is returned without any masking.
    max_sequence_length: When `time` is at or beyond this value
      `generate_logit` is not used and `zero_logit` is returned instead.
    zero_logit: Value used in place of a computed logit.
    generate_logit: Zero-argument callable producing the logit for this step.

  Returns:
    The logit for this step. Its static shape is taken from `zero_logit`
    when `zero_logit` exposes one.
  """
  # Step 1: decide whether the logit has to be computed at all.
  step_logit = control_flow_ops.cond(
      time < max_sequence_length, generate_logit, lambda: zero_logit)

  def copy_through():
    # Pick zero_logit for entries whose sequence has already ended and the
    # freshly computed logit for all the others.
    finished = (time >= sequence_length)
    return math_ops.select(finished, zero_logit, step_logit)

  # Step 2: masking is only needed once `time` reaches min_sequence_length.
  masked_logit = control_flow_ops.cond(
      time < min_sequence_length, lambda: step_logit, copy_through)

  # The previous code called set_shape() with the tensor's own shape, which
  # cannot add any information. Take the static shape from zero_logit
  # instead; the select above already treats both as interchangeable.
  if hasattr(zero_logit, "get_shape"):
    masked_logit.set_shape(zero_logit.get_shape())
  return masked_logit
loss_ops.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def _safe_div(numerator, denominator, name="value"):
  """Divides element-wise, yielding 0 wherever `denominator` is not positive.

  Zero entries of the denominator are swapped for ones before the division
  is built, so the division op itself never sees a zero. The original note
  on this function explains that this extra check is needed to keep NaNs out
  of the gradient computation when the loss is zero.

  Args:
    numerator: An arbitrary `Tensor`.
    denominator: A `Tensor` whose shape matches `numerator` and whose values are
      assumed to be non-negative.
    name: An optional name for the returned op.

  Returns:
    The element-wise value of the numerator divided by the denominator.
  """
  denominator_is_positive = math_ops.greater(denominator, 0)
  denominator_is_zero = math_ops.equal(denominator, 0)
  unit_denominator = array_ops.ones_like(denominator)
  nonzero_denominator = math_ops.select(
      denominator_is_zero, unit_denominator, denominator)
  quotient = math_ops.div(numerator, nonzero_denominator)
  zero_result = array_ops.zeros_like(numerator)
  return math_ops.select(
      denominator_is_positive, quotient, zero_result, name=name)
metric_ops.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def _safe_div(numerator, denominator, name):
  """Returns `numerator / denominator`, or 0 where `denominator` <= 0.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    `numerator` / `denominator` where `denominator` > 0, otherwise 0.
  """
  denominator_is_positive = math_ops.greater(denominator, 0)
  quotient = math_ops.truediv(numerator, denominator)
  return math_ops.select(denominator_is_positive, quotient, 0, name=name)
metric_ops.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
  """If class ID is specified, filter all other classes.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
      where N >= 1. Commonly, N=1 and `predictions_idx` has shape
      [batch size, k].
    selected_id: Int id to select.

  Returns:
    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
  """
  if selected_id is None:
    return labels, predictions_idx
  return (_select_class_id(labels, selected_id),
          _select_class_id(predictions_idx, selected_id))
quantized_distribution.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _sample_n(self, n, seed=None):
    """Samples the base distribution, rounds up, and clamps to the cutoffs.

    Args:
      n: Number of samples; converted to a tensor before use.
      seed: Seed forwarded to the base distribution's `sample_n`.

    Returns:
      The samples after `ceil`, with values below the lower cutoff replaced by
      it and values above the upper cutoff replaced by it (each only when the
      respective cutoff is not `None`).
    """
    low = self._lower_cutoff
    high = self._upper_cutoff
    with ops.name_scope("transform"):
      n = ops.convert_to_tensor(n, name="n")
      base_samples = self.base_distribution.sample_n(n=n, seed=seed)
      ones = array_ops.ones_like(base_samples)

      # Snap values to the intervals (j - 1, j].
      quantized = math_ops.ceil(base_samples)

      if low is not None:
        too_small = quantized < low
        quantized = math_ops.select(too_small, low * ones, quantized)

      if high is not None:
        too_large = quantized > high
        quantized = math_ops.select(too_large, high * ones, quantized)

      return quantized
quantized_distribution.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def _log_prob_with_logsf_and_logcdf(self, y):
    """Computes the log-probability from log survival and log CDF values.

    Two expressions agree in exact arithmetic:
      Log[ sf(y - 1) - sf(y) ]   = Log[ exp{logsf(y - 1)} - exp{logsf(y)} ]
      Log[ cdf(y) - cdf(y - 1) ] = Log[ exp{logcdf(y)} - exp{logcdf(y - 1)} ]
    Both have the form Log[ exp{big} - exp{small} ]; the survival-function
    pair is chosen where logsf(y) < logcdf(y), i.e. right of the median.

    Args:
      y: Points at which to evaluate.

    Returns:
      The result of `_logsum_expbig_minus_expsmall(big, small)`.
    """
    log_sf_here = self.log_survival_function(y)
    log_sf_prev = self.log_survival_function(y - 1)
    log_cdf_here = self.log_cdf(y)
    log_cdf_prev = self.log_cdf(y - 1)

    # Important: select is applied to the inputs of the final computation
    # rather than to two finished results, so that no input of select is inf.
    # Otherwise the output of select could be finite while the output of
    # grad(select) is NaN.
    larger_term = math_ops.select(
        log_sf_here < log_cdf_here, log_sf_prev, log_cdf_here)
    smaller_term = math_ops.select(
        log_sf_here < log_cdf_here, log_sf_here, log_cdf_prev)

    return _logsum_expbig_minus_expsmall(larger_term, smaller_term)
student_t.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def _variance(self):
    """Variance: sigma**2 * df / (df - 2) where df > 2, inf otherwise.

    Where df <= 1 the result is NaN if `allow_nan_stats` is set; otherwise an
    assertion that df > 1 is attached to the returned tensor.
    """
    scaled = self._ones() * math_ops.square(self.sigma) * self.df
    finite_variance = scaled / (self.df - 2)
    # When 1 < df <= 2, variance is infinite.
    inf = np.array(np.inf, dtype=self.dtype.as_numpy_dtype())
    df_above_two = math_ops.greater(
        self.df, array_ops.fill(self.batch_shape(), 2.))
    inf_fill = array_ops.fill(self.batch_shape(), inf, name="inf")
    variance_or_inf = math_ops.select(df_above_two, finite_variance, inf_fill)

    if not self.allow_nan_stats:
      df_check = check_ops.assert_less(
          array_ops.ones((), dtype=self.dtype), self.df,
          message="variance not defined for components of df <= 1")
      return control_flow_ops.with_dependencies([df_check], variance_or_inf)

    nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
    df_above_one = math_ops.greater(self.df, self._ones())
    nan_fill = array_ops.fill(self.batch_shape(), nan, name="nan")
    return math_ops.select(df_above_one, variance_or_inf, nan_fill)
beta.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def _mode(self):
    """Mode (a - 1) / (a_b_sum - 2), defined only where a > 1 and b > 1.

    Elsewhere the result is NaN if `allow_nan_stats` is set; otherwise
    assertions that a > 1 and b > 1 are attached to the returned tensor.
    """
    numerator = self.a - 1.
    mode = numerator / (self.a_b_sum - 2.)

    if not self.allow_nan_stats:
      a_check = check_ops.assert_less(
          array_ops.ones((), dtype=self.dtype), self.a,
          message="Mode not defined for components of a <= 1.")
      b_check = check_ops.assert_less(
          array_ops.ones((), dtype=self.dtype), self.b,
          message="Mode not defined for components of b <= 1.")
      return control_flow_ops.with_dependencies([a_check, b_check], mode)

    nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
    mode_is_defined = math_ops.logical_and(
        math_ops.greater(self.a, 1.),
        math_ops.greater(self.b, 1.))
    nan_fill = array_ops.fill(self.batch_shape(), nan, name="nan")
    return math_ops.select(mode_is_defined, mode, nan_fill)
dirichlet.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def _mode(self):
    """Mode (alpha - 1) / (alpha_sum - event_size), defined where alpha > 1.

    Elsewhere the result is NaN if `allow_nan_stats` is set; otherwise an
    assertion that alpha > 1 is attached to the returned tensor.
    """
    numerator = self.alpha - 1.
    expanded_sum = array_ops.expand_dims(self.alpha_sum, dim=-1)
    event_size = math_ops.cast(self.event_shape()[0], self.dtype)
    mode = numerator / (expanded_sum - event_size)

    if not self.allow_nan_stats:
      alpha_check = check_ops.assert_less(
          array_ops.ones((), dtype=self.dtype), self.alpha,
          message="mode not defined for components of alpha <= 1")
      return control_flow_ops.with_dependencies([alpha_check], mode)

    nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
    # The NaN filler spans the batch dimensions followed by the event ones.
    full_shape = array_ops.concat(0, (self.batch_shape(), self.event_shape()))
    mode_is_defined = math_ops.greater(self.alpha, 1.)
    nan_fill = array_ops.fill(full_shape, nan, name="nan")
    return math_ops.select(mode_is_defined, mode, nan_fill)
metric_ops.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def _safe_div(numerator, denominator, name):
  """Returns `numerator / denominator`, or 0 where `denominator` <= 0.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    `numerator` / `denominator` where `denominator` > 0, otherwise 0.
  """
  denominator_is_positive = math_ops.greater(denominator, 0)
  quotient = math_ops.truediv(numerator, denominator)
  return math_ops.select(denominator_is_positive, quotient, 0, name=name)
metric_ops.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
  """If class ID is specified, filter all other classes.

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
      target classes for the associated prediction. Commonly, N=1 and `labels`
      has shape [batch_size, num_labels]. [D1, ... DN] must match
      `predictions_idx`.
    predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
      where N >= 1. Commonly, N=1 and `predictions_idx` has shape
      [batch size, k].
    selected_id: Int id to select.

  Returns:
    Tuple of `labels` and `predictions_idx`, possibly with classes removed.
  """
  if selected_id is None:
    return labels, predictions_idx
  return (_select_class_id(labels, selected_id),
          _select_class_id(predictions_idx, selected_id))
quantized_distribution.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def _sample_n(self, n, seed=None):
    """Samples the wrapped distribution, rounds up, and clamps to the cutoffs.

    Args:
      n: Number of samples; converted to a tensor before use.
      seed: Seed forwarded to the wrapped distribution's `sample_n`.

    Returns:
      The samples after `ceil`, with values below the lower cutoff replaced by
      it and values above the upper cutoff replaced by it (each only when the
      respective cutoff is not `None`).
    """
    low = self._lower_cutoff
    high = self._upper_cutoff
    with ops.name_scope("transform"):
      n = ops.convert_to_tensor(n, name="n")
      base_samples = self.distribution.sample_n(n=n, seed=seed)
      ones = array_ops.ones_like(base_samples)

      # Snap values to the intervals (j - 1, j].
      quantized = math_ops.ceil(base_samples)

      if low is not None:
        too_small = quantized < low
        quantized = math_ops.select(too_small, low * ones, quantized)

      if high is not None:
        too_large = quantized > high
        quantized = math_ops.select(too_large, high * ones, quantized)

      return quantized
student_t.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _variance(self):
    """Variance: sigma**2 * df / (df - 2) where df > 2, inf otherwise.

    Where df <= 1 the result is NaN if `allow_nan_stats` is set; otherwise an
    assertion that df > 1 is attached to the returned tensor.
    """
    scaled = self._ones() * math_ops.square(self.sigma) * self.df
    finite_variance = scaled / (self.df - 2)
    # When 1 < df <= 2, variance is infinite.
    inf = np.array(np.inf, dtype=self.dtype.as_numpy_dtype())
    df_above_two = math_ops.greater(
        self.df, array_ops.fill(self.batch_shape(), 2.))
    inf_fill = array_ops.fill(self.batch_shape(), inf, name="inf")
    variance_or_inf = math_ops.select(df_above_two, finite_variance, inf_fill)

    if not self.allow_nan_stats:
      df_check = check_ops.assert_less(
          array_ops.ones((), dtype=self.dtype), self.df,
          message="variance not defined for components of df <= 1")
      return control_flow_ops.with_dependencies([df_check], variance_or_inf)

    nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
    df_above_one = math_ops.greater(self.df, self._ones())
    nan_fill = array_ops.fill(self.batch_shape(), nan, name="nan")
    return math_ops.select(df_above_one, variance_or_inf, nan_fill)
beta.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _mode(self):
    """Mode (a - 1) / (a_b_sum - 2), defined only where a > 1 and b > 1.

    Elsewhere the result is NaN if `allow_nan_stats` is set; otherwise
    assertions that a > 1 and b > 1 are attached to the returned tensor.
    """
    numerator = self.a - 1.
    mode = numerator / (self.a_b_sum - 2.)

    if not self.allow_nan_stats:
      a_check = check_ops.assert_less(
          array_ops.ones((), dtype=self.dtype), self.a,
          message="Mode not defined for components of a <= 1.")
      b_check = check_ops.assert_less(
          array_ops.ones((), dtype=self.dtype), self.b,
          message="Mode not defined for components of b <= 1.")
      return control_flow_ops.with_dependencies([a_check, b_check], mode)

    nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
    mode_is_defined = math_ops.logical_and(
        math_ops.greater(self.a, 1.),
        math_ops.greater(self.b, 1.))
    nan_fill = array_ops.fill(self.batch_shape(), nan, name="nan")
    return math_ops.select(mode_is_defined, mode, nan_fill)
utils.py 文件源码 项目:emoatt 作者: epochx 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def _step(time, sequence_length, min_sequence_length, max_sequence_length, zero_logit, generate_logit):
  """Builds the logit for one time step, zeroing examples that have ended.

  Args:
    time: Current step; compared against the three length arguments.
    sequence_length: Lengths compared element-wise with `time`; entries with
      `time >= sequence_length` receive `zero_logit` (presumably one length
      per batch example -- confirm against the caller).
    min_sequence_length: When `time` is below this value the computed logit
      is returned without any masking.
    max_sequence_length: When `time` is at or beyond this value
      `generate_logit` is not used and `zero_logit` is returned instead.
    zero_logit: Value used in place of a computed logit.
    generate_logit: Zero-argument callable producing the logit for this step.

  Returns:
    The logit for this step. Its static shape is taken from `zero_logit`
    when `zero_logit` exposes one.
  """
  # Step 1: decide whether the logit has to be computed at all.
  step_logit = control_flow_ops.cond(
      time < max_sequence_length, generate_logit, lambda: zero_logit)

  def copy_through():
    # Pick zero_logit for entries whose sequence has already ended and the
    # freshly computed logit for all the others.
    finished = (time >= sequence_length)
    return math_ops.select(finished, zero_logit, step_logit)

  # Step 2: masking is only needed once `time` reaches min_sequence_length.
  masked_logit = control_flow_ops.cond(
      time < min_sequence_length, lambda: step_logit, copy_through)

  # The previous code called set_shape() with the tensor's own shape, which
  # cannot add any information. Take the static shape from zero_logit
  # instead; the select above already treats both as interchangeable.
  if hasattr(zero_logit, "get_shape"):
    masked_logit.set_shape(zero_logit.get_shape())
  return masked_logit
special_math.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 17 收藏 0 点赞 0 评论 0
def _ndtr(x):
  """Core of ndtr: evaluates 0.5 * y with y pieced together from erf/erfc.

  With w = x * sqrt(2) / 2 and z = |w|:
    y = 1 + erf(w)   where z < sqrt(2) / 2,
    y = 2 - erfc(z)  where w > 0 (and z is not below the threshold),
    y = erfc(z)      otherwise.
  """
  half_sqrt_2 = constant_op.constant(
      0.5 * math.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
  w = x * half_sqrt_2
  z = math_ops.abs(w)

  near_zero = math_ops.less(z, half_sqrt_2)
  central_value = 1. + math_ops.erf(w)
  w_is_positive = math_ops.greater(w, 0.)
  upper_tail_value = 2. - math_ops.erfc(z)
  lower_tail_value = math_ops.erfc(z)
  tail_value = math_ops.select(
      w_is_positive, upper_tail_value, lower_tail_value)
  y = math_ops.select(near_zero, central_value, tail_value)
  return 0.5 * y
inverse_gamma.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def _mean(self):
    """Mean beta / (alpha - 1), defined only where alpha > 1.

    Elsewhere the result is NaN if `allow_nan_stats` is set; otherwise an
    assertion that alpha > 1 is attached to the returned tensor.
    """
    mean = self.beta / (self.alpha - 1.)

    if not self.allow_nan_stats:
      alpha_check = check_ops.assert_less(
          array_ops.ones((), self.dtype), self.alpha,
          message="mean not defined for components of self.alpha <= 1")
      return control_flow_ops.with_dependencies([alpha_check], mean)

    nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
    mean_is_defined = self.alpha > 1.
    nan_fill = array_ops.fill(self.batch_shape(), nan, name="nan")
    return math_ops.select(mean_is_defined, mean, nan_fill)
inverse_gamma.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def _variance(self):
    """Variance beta**2 / ((alpha - 1)**2 * (alpha - 2)), defined for alpha > 2.

    Elsewhere the result is NaN if `allow_nan_stats` is set; otherwise an
    assertion that alpha > 2 is attached to the returned tensor.
    """
    beta_squared = math_ops.square(self.beta)
    denominator = math_ops.square(self.alpha - 1.) * (self.alpha - 2.)
    var = beta_squared / denominator

    if not self.allow_nan_stats:
      alpha_check = check_ops.assert_less(
          constant_op.constant(2., dtype=self.dtype), self.alpha,
          message="variance not defined for components of alpha <= 2")
      return control_flow_ops.with_dependencies([alpha_check], var)

    nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
    variance_is_defined = self.alpha > 2.
    nan_fill = array_ops.fill(self.batch_shape(), nan, name="nan")
    return math_ops.select(variance_is_defined, var, nan_fill)
quantized_distribution.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _prob_with_sf_and_cdf(self, y):
    """Computes the probability from survival-function and CDF differences.

    In exact arithmetic `sf(y - 1) - sf(y)` and `cdf(y) - cdf(y - 1)` are
    equal. The survival-function form is used where `sf(y) < cdf(y)`, i.e. on
    the right side of the median, where it has greater precision.

    Args:
      y: Points at which to evaluate.

    Returns:
      The element-wise selection between the two differences.
    """
    sf_here = self.survival_function(y)
    sf_prev = self.survival_function(y - 1)
    cdf_here = self.cdf(y)
    cdf_prev = self.cdf(y - 1)

    # True iff we're on the right side of the median.
    right_of_median = sf_here < cdf_here
    prob_from_sf = sf_prev - sf_here
    prob_from_cdf = cdf_here - cdf_prev
    return math_ops.select(right_of_median, prob_from_sf, prob_from_cdf)
quantized_distribution.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _log_cdf(self, y):
    """Log of the CDF of the quantized distribution at `y`.

    Recall the promise:
      cdf(y) := P[Y <= y]
              = 1, if y >= upper_cutoff,
              = 0, if y < lower_cutoff,
              = P[X <= y], otherwise.

    Args:
      y: Points at which to evaluate.

    Returns:
      The base distribution's `log_cdf(floor(y))`, overridden with -inf below
      the lower cutoff and with 0 at or above the upper cutoff (each only when
      the respective cutoff is not `None`).
    """
    low = self._lower_cutoff
    high = self._upper_cutoff

    # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
    # between.
    floored = math_ops.floor(y)

    log_cdf = self.base_distribution.log_cdf(floored)

    # Broadcast, because it's possible that this is a single distribution being
    # evaluated on a number of samples, or something like that.
    floored += array_ops.zeros_like(log_cdf)

    # Re-define values at the cutoffs.
    if low is not None:
      neg_inf = -np.inf * array_ops.ones_like(log_cdf)
      below_low = floored < low
      log_cdf = math_ops.select(below_low, neg_inf, log_cdf)
    if high is not None:
      at_or_above_high = floored >= high
      log_one = array_ops.zeros_like(log_cdf)
      log_cdf = math_ops.select(at_or_above_high, log_one, log_cdf)

    return log_cdf
quantized_distribution.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def _cdf(self, y):
    """CDF of the quantized distribution at `y`.

    Recall the promise:
      cdf(y) := P[Y <= y]
              = 1, if y >= upper_cutoff,
              = 0, if y < lower_cutoff,
              = P[X <= y], otherwise.

    Args:
      y: Points at which to evaluate.

    Returns:
      The base distribution's `cdf(floor(y))`, overridden with 0 below the
      lower cutoff and with 1 at or above the upper cutoff (each only when the
      respective cutoff is not `None`).
    """
    low = self._lower_cutoff
    high = self._upper_cutoff

    # P[Y <= j] = P[floor(Y) <= j] since mass is only at integers, not in
    # between.
    floored = math_ops.floor(y)

    # P[X <= j], used when lower_cutoff < X < upper_cutoff.
    cdf = self.base_distribution.cdf(floored)

    # Broadcast, because it's possible that this is a single distribution being
    # evaluated on a number of samples, or something like that.
    floored += array_ops.zeros_like(cdf)

    # Re-define values at the cutoffs.
    if low is not None:
      below_low = floored < low
      zeros = array_ops.zeros_like(cdf)
      cdf = math_ops.select(below_low, zeros, cdf)
    if high is not None:
      at_or_above_high = floored >= high
      ones = array_ops.ones_like(cdf)
      cdf = math_ops.select(at_or_above_high, ones, cdf)

    return cdf
quantized_distribution.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def _survival_function(self, y):
    """Survival function of the quantized distribution at `y`.

    Recall the promise:
      survival_function(y) := P[Y > y]
                            = 0, if y >= upper_cutoff,
                            = 1, if y < lower_cutoff,
                            = P[X > y], otherwise.

    Args:
      y: Points at which to evaluate.

    Returns:
      The base distribution's `survival_function(ceil(y))`, overridden with 1
      below the lower cutoff and with 0 at or above the upper cutoff (each
      only when the respective cutoff is not `None`).
    """
    low = self._lower_cutoff
    high = self._upper_cutoff

    # P[Y > j] = P[ceiling(Y) > j] since mass is only at integers, not in
    # between.
    ceiled = math_ops.ceil(y)

    # P[X > j], used when lower_cutoff < X < upper_cutoff.
    survival = self.base_distribution.survival_function(ceiled)

    # Broadcast, because it's possible that this is a single distribution being
    # evaluated on a number of samples, or something like that.
    ceiled += array_ops.zeros_like(survival)

    # Re-define values at the cutoffs.
    if low is not None:
      below_low = ceiled < low
      ones = array_ops.ones_like(survival)
      survival = math_ops.select(below_low, ones, survival)
    if high is not None:
      at_or_above_high = ceiled >= high
      zeros = array_ops.zeros_like(survival)
      survival = math_ops.select(at_or_above_high, zeros, survival)

    return survival
student_t.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _mean(self):
    """Mean `mu` (multiplied by `_ones()`), defined only where df > 1.

    Elsewhere the result is NaN if `allow_nan_stats` is set; otherwise an
    assertion that df > 1 is attached to the returned tensor.
    """
    mean = self.mu * self._ones()

    if not self.allow_nan_stats:
      df_check = check_ops.assert_less(
          array_ops.ones((), dtype=self.dtype), self.df,
          message="mean not defined for components of df <= 1")
      return control_flow_ops.with_dependencies([df_check], mean)

    nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
    mean_is_defined = math_ops.greater(self.df, self._ones())
    nan_fill = array_ops.fill(self.batch_shape(), nan, name="nan")
    return math_ops.select(mean_is_defined, mean, nan_fill)
uniform.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def _prob(self, x):
    """Density at `x`: 1 / range() inside [a, b], 0 outside, NaN for NaN input.

    Args:
      x: Points at which to evaluate; multiplied by ones of the batch shape
        before use.

    Returns:
      The element-wise density, with NaN inputs passed through unchanged.
    """
    broadcasted_x = x * array_ops.ones(self.batch_shape())

    x_is_nan = math_ops.is_nan(broadcasted_x)
    outside_support = math_ops.logical_or(broadcasted_x < self.a,
                                          broadcasted_x > self.b)
    zero_density = array_ops.zeros_like(broadcasted_x)
    inside_density = (1. / self.range()) * array_ops.ones_like(broadcasted_x)
    density = math_ops.select(outside_support, zero_density, inside_density)
    return math_ops.select(x_is_nan, broadcasted_x, density)
uniform.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def _cdf(self, x):
    """CDF at `x`: 0 where x < a, 1 where x >= b, (x - a) / range() otherwise.

    Args:
      x: Points at which to evaluate.

    Returns:
      The element-wise CDF values.
    """
    broadcasted_x = x * array_ops.ones(self.batch_shape())
    # The constant fillers take their shape from x + a + b.
    zeros = array_ops.zeros_like(x + self.a + self.b, dtype=self.dtype)
    ones = array_ops.ones_like(x + self.a + self.b, dtype=self.dtype)

    below_a = x < self.a
    linear_part = (broadcasted_x - self.a) / self.range()
    result_if_not_big = math_ops.select(below_a, zeros, linear_part)

    at_or_above_b = x >= self.b
    return math_ops.select(at_or_above_b, ones, result_if_not_big)
special_math.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def _ndtr(x):
  """Core of ndtr: evaluates 0.5 * y with y pieced together from erf/erfc.

  With w = x * sqrt(2) / 2 and z = |w|:
    y = 1 + erf(w)   where z < sqrt(2) / 2,
    y = 2 - erfc(z)  where w > 0 (and z is not below the threshold),
    y = erfc(z)      otherwise.
  """
  half_sqrt_2 = constant_op.constant(
      0.5 * math.sqrt(2.), dtype=x.dtype, name="half_sqrt_2")
  w = x * half_sqrt_2
  z = math_ops.abs(w)

  near_zero = math_ops.less(z, half_sqrt_2)
  central_value = 1. + math_ops.erf(w)
  w_is_positive = math_ops.greater(w, 0.)
  upper_tail_value = 2. - math_ops.erfc(z)
  lower_tail_value = math_ops.erfc(z)
  tail_value = math_ops.select(
      w_is_positive, upper_tail_value, lower_tail_value)
  y = math_ops.select(near_zero, central_value, tail_value)
  return 0.5 * y
metric_ops.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def _select_class_id(ids, selected_id):
  """Keeps only the entries of `ids` that equal `selected_id`.

  Args:
    ids: `int64` `Tensor` or `SparseTensor` of IDs.
    selected_id: Int id to select.

  Returns:
    `SparseTensor` of same dimensions as `ids`. This contains only the entries
    equal to `selected_id`.
  """
  sparse_types = (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)
  if isinstance(ids, sparse_types):
    # Sparse input: retain exactly the values matching the selected ID.
    matches = math_ops.equal(ids.values, selected_id)
    return sparse_ops.sparse_retain(ids, matches)

  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
  # tf.equal and tf.reduce_any?

  # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
  dense_shape = array_ops.shape(ids, out_type=dtypes.int64)
  last_axis = array_ops.size(dense_shape) - 1
  collapsed_shape = math_ops.reduced_shape(
      dense_shape, array_ops.reshape(last_axis, [1]))

  # Intersect `ids` with the selected ID.
  selected_filled = array_ops.fill(
      collapsed_shape, math_ops.to_int64(selected_id))
  intersection = set_ops.set_intersection(selected_filled, ids)
  return sparse_tensor.SparseTensor(
      indices=intersection.indices,
      values=intersection.values,
      shape=dense_shape)
inverse_gamma.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def _mean(self):
    """Mean beta / (alpha - 1), defined only where alpha > 1.

    Elsewhere the result is NaN if `allow_nan_stats` is set; otherwise an
    assertion that alpha > 1 is attached to the returned tensor.
    """
    mean = self.beta / (self.alpha - 1.)

    if not self.allow_nan_stats:
      alpha_check = check_ops.assert_less(
          array_ops.ones((), self.dtype), self.alpha,
          message="mean not defined for components of self.alpha <= 1")
      return control_flow_ops.with_dependencies([alpha_check], mean)

    nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
    mean_is_defined = self.alpha > 1.
    nan_fill = array_ops.fill(self.batch_shape(), nan, name="nan")
    return math_ops.select(mean_is_defined, mean, nan_fill)
inverse_gamma.py 文件源码 项目:lsdc 作者: febert 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def _variance(self):
    """Variance beta**2 / ((alpha - 1)**2 * (alpha - 2)), defined for alpha > 2.

    Elsewhere the result is NaN if `allow_nan_stats` is set; otherwise an
    assertion that alpha > 2 is attached to the returned tensor.
    """
    beta_squared = math_ops.square(self.beta)
    denominator = math_ops.square(self.alpha - 1.) * (self.alpha - 2.)
    var = beta_squared / denominator

    if not self.allow_nan_stats:
      alpha_check = check_ops.assert_less(
          constant_op.constant(2., dtype=self.dtype), self.alpha,
          message="variance not defined for components of alpha <= 2")
      return control_flow_ops.with_dependencies([alpha_check], var)

    nan = np.array(np.nan, dtype=self.dtype.as_numpy_dtype())
    variance_is_defined = self.alpha > 2.
    nan_fill = array_ops.fill(self.batch_shape(), nan, name="nan")
    return math_ops.select(variance_is_defined, var, nan_fill)


问题


面经


文章

微信
公众号

扫码关注公众号