Python float16() usage examples (source code)
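
The snippets below are collected from open-source projects; most follow the same pattern: pick tf.float16 when a use_fp16 flag is set (otherwise tf.float32), build variables in that dtype, and cast input tensors to it. A minimal, self-contained sketch of the pattern (the plain Python bool here stands in for the FLAGS.use_fp16 flag used throughout the snippets):

import numpy as np
import tensorflow as tf

use_fp16 = True  # stands in for FLAGS.use_fp16 in the snippets below
dtype = tf.float16 if use_fp16 else tf.float32

images = tf.constant(np.random.rand(4, 8, 8, 1).astype(np.float32))
images = tf.cast(images, dtype)   # cast inputs to the chosen precision
print(images.dtype)               # <dtype: 'float16'>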

Source: bbbc006.py (project: dcan-tensorflow, author: lisjin)
def _variable_with_weight_decay(name, shape, stddev, wd):
    """Helper to create an initialized Variable with weight decay.
    Note that the Variable is initialized with a truncated normal distribution.
    A weight decay is added only if one is specified.
    Args:
        name: name of the variable
        shape: list of ints
        stddev: standard deviation of a truncated Gaussian
        wd: add L2Loss weight decay multiplied by this float. If None, weight
            decay is not added for this Variable.
    Returns:
        Variable Tensor
    """
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    var = _variable_on_cpu(
        name,
        shape,
        tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
    if wd is not None and not tf.get_variable_scope().reuse:
        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var
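A hedged sketch of how a helper like this is typically used: the returned variable serves as a convolution kernel, and the weight-decay terms accumulated in the 'losses' collection are later summed into the total loss. The flag definition, the _variable_on_cpu stand-in, and the shapes below are illustrative, not taken from the project:

import tensorflow as tf

tf.app.flags.DEFINE_boolean('use_fp16', False, 'Train the model using fp16.')
FLAGS = tf.app.flags.FLAGS

# Illustrative stand-in for _variable_on_cpu (presumably defined elsewhere in the project).
def _variable_on_cpu(name, shape, initializer):
    with tf.device('/cpu:0'):
        dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
        return tf.get_variable(name, shape, initializer=initializer, dtype=dtype)

# Typical call: a 5x5 conv kernel with 1 input and 32 output channels, weight decay 0.004.
kernel = _variable_with_weight_decay('weights', shape=[5, 5, 1, 32],
                                     stddev=5e-2, wd=0.004)
weight_decay_loss = tf.add_n(tf.get_collection('losses'))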
Source: bbbc006.py (project: dcan-tensorflow, author: lisjin)
def inputs(eval_data):
    """Construct input for BBBC006 evaluation using the Reader ops.
    Args:
        eval_data: bool, indicating if one should use the train or eval data set.
    Returns:
        images: Images. 4D tensor of [batch_size, IMAGE_WIDTH, IMAGE_HEIGHT, 1] size.
        labels: Labels. 4D tensor of [batch_size, IMAGE_WIDTH, IMAGE_HEIGHT, 2] size.

    Raises:
        ValueError: If no data_dir
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    images, labels = bbbc006_input.inputs(eval_data=eval_data,
                                          batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels
Source: net_model.py (project: 3D_CNN_jonas, author: 2015ZxEE)
def variable_with_weight_decay(name, shape, stddev, wd):
    """
    Note that the Variable is initialized with a truncated normal distribution.
    A weight decay is added only if one is specified.
    Args:
        name   -> name of the variable
        shape  -> list of ints
        stddev -> standard deviation of a truncated Gaussian
        wd     -> add L2Loss weight decay multiplied by this float.
                        If None, weight decay is not added for this Variable.
    Rtns:
        var    -> variable tensor
    """
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    var   = variable_on_cpu(name,shape,
                    tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
    if wd is not None:
        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var
Source: net_model.py (project: 3D_CNN_jonas, author: 2015ZxEE)
def inputs_train():
    """
    Args:
        nothing
    Rtns:
        img3_batch  -> 5D float32 or float16 tensor of [batch_size,h,w,d,c]
        label_batch -> 1D float32 or float16 tensor of [batch_size]
    Raises:
        ValueError -> If no data_dir
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir                = os.path.join(FLAGS.data_dir)
    img3_batch, label_batch = in_data.inputs_train(data_dir=data_dir,
                                                    batch_size=FLAGS.batch_size)

    if FLAGS.use_fp16:
        img3_batch  = tf.cast(img3_batch, tf.float16)
        label_batch = tf.cast(label_batch, tf.float16)
    return img3_batch, label_batch
Source: net_model.py (project: 3D_CNN_jonas, author: 2015ZxEE)
def inputs_eval():
    """
    Args:
        nothing
    Rtns:
        img3_batch  -> 5D float32 or float16 tensor of [batch_size,h,w,d,c]
        label_batch -> 1D float32 or float16 tensor of [batch_size]
    Raises:
        ValueError -> If no data_dir
      """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir               = os.path.join(FLAGS.data_dir)
    img3_batch, label_batch = in_data.inputs_eval(data_dir=data_dir,
                                                    batch_size=FLAGS.batch_size)

    if FLAGS.use_fp16:
        img3_batch   = tf.cast(img3_batch, tf.float16)
        label_batch  = tf.cast(label_batch, tf.float16)
    return img3_batch, label_batch
Source: tensorport.py (project: jack, author: uclmr)
def create_torch_variable(self, value, gpu=False):
        """Convenience method that produces a tensor given the value of the defined type.

        Returns: a torch tensor of same type.
        """
        if isinstance(value, torch.autograd.Variable):
            if gpu:
                value = value.cuda()
            return value
        if not torch.is_tensor(value):
            if not isinstance(value, np.ndarray):
                value = np.array(value, dtype=self.dtype.as_numpy_dtype)
            else:
                value = value.astype(self.dtype.as_numpy_dtype)
            if value.size == 0:
                return value
            allowed = [tf.int16, tf.int32, tf.int64, tf.float16, tf.float32, tf.float64, tf.int8]
            if self.dtype in allowed:
                value = torch.autograd.Variable(torch.from_numpy(value))
        else:
            value = torch.autograd.Variable(value)
        if gpu and isinstance(value, torch.autograd.Variable):
            value = value.cuda()
        return value
Source: sg_optimize.py (project: sugartensor, author: buriburisuri)
def _apply_dense(self, grad, var):
        lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
        if var.dtype.base_dtype == tf.float16:
            eps = 1e-7  # Can't use 1e-8 due to underflow -- not sure if it makes a big difference.
        else:
            eps = 1e-8

        v = self.get_slot(var, "v")
        v_t = v.assign(beta1_t * v + (1. - beta1_t) * grad)
        m = self.get_slot(var, "m")
        m_t = m.assign(tf.maximum(beta2_t * m + eps, tf.abs(grad)))
        g_t = v_t / m_t

        var_update = tf.assign_sub(var, lr_t * g_t)
        return tf.group(*[var_update, m_t, v_t])
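The float16 branch above matters because 1e-8 underflows to exactly zero in half precision (the smallest positive float16 subnormal is about 6e-8), which would leave the tf.maximum term without any floor. A quick numpy check, purely illustrative:

import numpy as np

print(np.float16(1e-8) == 0.0)     # True: 1e-8 underflows to zero in half precision
print(np.float16(1e-7) == 0.0)     # False: 1e-7 is still representable (as a subnormal)
print(np.finfo(np.float16).tiny)   # smallest normal float16, about 6.1e-05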
Source: tensorflow_backend.py (project: keras, author: GeekLiB)
def _convert_string_dtype(dtype):
    if dtype == 'float16':
        return tf.float16
    if dtype == 'float32':
        return tf.float32
    elif dtype == 'float64':
        return tf.float64
    elif dtype == 'int16':
        return tf.int16
    elif dtype == 'int32':
        return tf.int32
    elif dtype == 'int64':
        return tf.int64
    elif dtype == 'uint8':
        return tf.uint8
    elif dtype == 'uint16':
        return tf.uint16
    else:
        raise ValueError('Unsupported dtype:', dtype)
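A quick illustrative call of the helper above (assuming tensorflow is imported as tf, as in the backend module):

print(_convert_string_dtype('float16'))  # <dtype: 'float16'>
print(_convert_string_dtype('uint16'))   # <dtype: 'uint16'>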
Source: translate.py (project: tf-seq2seq-mod, author: knok)
def create_model(session, forward_only):
  """Create translation model and initialize or load parameters in session."""
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  model = seq2seq_model.Seq2SeqModel(
      FLAGS.en_vocab_size,
      FLAGS.fr_vocab_size,
      _buckets,
      FLAGS.size,
      FLAGS.num_layers,
      FLAGS.max_gradient_norm,
      FLAGS.batch_size,
      FLAGS.learning_rate,
      FLAGS.learning_rate_decay_factor,
      use_lstm = FLAGS.use_lstm,
      forward_only=forward_only,
      dtype=dtype)
  ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
  if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path):
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    model.saver.restore(session, ckpt.model_checkpoint_path)
  else:
    print("Created model with fresh parameters.")
    session.run(tf.initialize_all_variables())
  return model
Source: optimizers.py (project: zhusuan, author: thu-ml)
def _apply_dense(self, grad, var):
        lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
        beta1_t = tf.cast(self._beta1_t, var.dtype.base_dtype)
        beta2_t = tf.cast(self._beta2_t, var.dtype.base_dtype)
        if var.dtype.base_dtype == tf.float16:
            # Can't use 1e-8 due to underflow
            eps = 1e-7
        else:
            eps = 1e-8

        v = self.get_slot(var, "v")
        v_t = v.assign(beta1_t * v + (1. - beta1_t) * grad)
        m = self.get_slot(var, "m")
        m_t = m.assign(tf.maximum(beta2_t * m + eps, tf.abs(grad)))
        g_t = v_t / m_t

        var_update = tf.assign_sub(var, lr_t * g_t)
        return tf.group(*[var_update, m_t, v_t])
Source: utils.py (project: zhusuan, author: thu-ml)
def assert_same_float_dtype(tensors_with_name, dtype=None):
    """
    Whether all types of tensors in `tensors` are the same and floating type.

    :param tensors_with_name: A list of (tensor, tensor_name).
    :param dtype: Expected type. If `None`, depend on the type of tensors.
    :return: The type of `tensors`.
    """

    floating_types = [tf.float16, tf.float32, tf.float64]
    if dtype is None:
        return assert_same_specific_dtype(tensors_with_name, floating_types)
    elif dtype in floating_types:
        return assert_same_dtype(tensors_with_name, dtype)
    else:
        raise TypeError("The argument 'dtype' must be in %s" % floating_types)
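A hedged usage sketch of the helper above: two tensors of the same floating dtype return that dtype, while mixing float16 with float32 (or passing integer tensors) is expected to raise. The tensors and names below are illustrative:

import tensorflow as tf

a = tf.zeros([2, 3], dtype=tf.float16)
b = tf.ones([3], dtype=tf.float16)
common = assert_same_float_dtype([(a, 'a'), (b, 'b')])
print(common)  # <dtype: 'float16'>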
Source: utils.py (project: IDNNs, author: ravidziv)
def _convert_string_dtype(dtype):
    if dtype == 'float16':
        return tf.float16
    if dtype == 'float32':
        return tf.float32
    elif dtype == 'float64':
        return tf.float64
    elif dtype == 'int16':
        return tf.int16
    elif dtype == 'int32':
        return tf.int32
    elif dtype == 'int64':
        return tf.int64
    elif dtype == 'uint8':
        return tf.uint8
    elif dtype == 'uint16':
        return tf.uint16
    else:
        raise ValueError('Unsupported dtype:', dtype)
Source: cifar10_gtf.py (project: deep_learning_study, author: jowettcz)
def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  var = _variable_on_cpu(
      name,
      shape,
      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
  if wd is not None:
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var
Source: cifar10_gtf.py (project: deep_learning_study, author: jowettcz)
def distorted_inputs():
  """Construct distorted input for CIFAR training using the Reader ops.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not FLAGS.data_dir:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                  batch_size=FLAGS.batch_size)
  if FLAGS.use_fp16:
    images = tf.cast(images, tf.float16)
    labels = tf.cast(labels, tf.float16)
  return images, labels
Source: cifar10_gtf.py (project: deep_learning_study, author: jowettcz)
def inputs(eval_data):
  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not FLAGS.data_dir:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  images, labels = cifar10_input.inputs(eval_data=eval_data,
                                        data_dir=data_dir,
                                        batch_size=FLAGS.batch_size)
  if FLAGS.use_fp16:
    images = tf.cast(images, tf.float16)
    labels = tf.cast(labels, tf.float16)
  return images, labels
Source: conv_test.py (project: sonnet, author: deepmind)
def testInputTypeError(self, use_bias):
    """Errors are thrown for invalid input types."""
    conv1 = snt.Conv2D(output_channels=1,
                       kernel_shape=3,
                       stride=1,
                       padding=snt.SAME,
                       name="conv1",
                       use_bias=use_bias,
                       initializers=create_constant_initializers(
                           1.0, 1.0, use_bias))

    for dtype in (tf.float16, tf.float64):
      x = tf.constant(np.ones([1, 5, 5, 1]), dtype=dtype)
      err = "Input must have dtype tf.float32.*"
      with self.assertRaisesRegexp(TypeError, err):
        conv1(x)
Source: conv_test.py (project: sonnet, author: deepmind)
def testInputTypeError(self, use_bias):
    """Errors are thrown for invalid input types."""
    conv1 = snt.Conv1D(output_channels=1,
                       kernel_shape=3,
                       stride=1,
                       padding=snt.VALID,
                       use_bias=use_bias,
                       name="conv1",
                       initializers=create_constant_initializers(
                           1.0, 1.0, use_bias))

    for dtype in (tf.float16, tf.float64):
      x = tf.constant(np.ones([1, 5, 1]), dtype=dtype)
      err = "Input must have dtype tf.float32.*"
      with self.assertRaisesRegexp(TypeError, err):
        conv1(x)
Source: conv_test.py (project: sonnet, author: deepmind)
def testInputTypeError(self, batch_size, in_length, in_channels, out_channels,
                         kernel_shape, padding, use_bias, out_shape,
                         stride_shape):
    """Errors are thrown for invalid input types."""
    conv1 = snt.Conv1DTranspose(
        output_channels=out_channels,
        output_shape=out_shape,
        kernel_shape=kernel_shape,
        padding=padding,
        stride=stride_shape,
        name="conv1",
        use_bias=use_bias)

    for dtype in (tf.float16, tf.float64):
      x = tf.constant(np.ones([batch_size, in_length,
                               in_channels]), dtype=dtype)
      err = "Input must have dtype tf.float32.*"
      with self.assertRaisesRegexp(TypeError, err):
        conv1(x)
Source: conv_test.py (project: sonnet, author: deepmind)
def testInputTypeError(self, use_bias):
    """Test that errors are thrown for invalid input types."""
    conv1 = snt.SeparableConv2D(
        output_channels=3,
        channel_multiplier=1,
        kernel_shape=3,
        padding=snt.SAME,
        use_bias=use_bias,
        initializers=create_separable_constant_initializers(
            1.0, 1.0, 1.0, use_bias))

    for dtype in (tf.float16, tf.float64):
      x = tf.constant(np.ones([1, 5, 5, 1]), dtype=dtype)
      err = "Input must have dtype tf.float32.*"
      with self.assertRaisesRegexp(TypeError, err):
        conv1(x)
Source: basic_test.py (project: sonnet, author: deepmind)
def testVariableInitialization(self):
    # Check that a simple operation involving the TrainableVariable
    # matches the result of the corresponding operation in numpy
    np.random.seed(100)
    types = (tf.float16, tf.float32, tf.float64)
    tol = (1e-2, 1e-6, 1e-9)
    tolerance_map = dict(zip(types, tol))
    lhs_shape = [3, 4]
    rhs_shape = [4, 6]
    for dtype in types:
      x = tf.placeholder(dtype, shape=lhs_shape)
      var = snt.TrainableVariable(shape=rhs_shape,
                                  dtype=dtype,
                                  initializers={"w": _test_initializer()})
      y = tf.matmul(x, var())
      with self.test_session() as sess:
        lhs_matrix = np.random.randn(*lhs_shape)
        sess.run(tf.global_variables_initializer())
        product, w = sess.run([y, var.w], {x: lhs_matrix})
      self.assertAllClose(product,
                          np.dot(
                              lhs_matrix.astype(dtype.as_numpy_dtype),
                              w.astype(dtype.as_numpy_dtype)),
                          atol=tolerance_map[dtype],
                          rtol=tolerance_map[dtype])
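The per-dtype tolerances (1e-2, 1e-6, 1e-9) track each dtype's machine epsilon, since the matmul only accumulates rounding error on roughly that scale. An illustrative check of the epsilons, not part of the test:

import numpy as np

for np_dtype in (np.float16, np.float32, np.float64):
    print(np_dtype.__name__, np.finfo(np_dtype).eps)
# float16 ~9.8e-04, float32 ~1.2e-07, float64 ~2.2e-16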
Source: cifar10.py (project: MachineLearningTutorial, author: SpikeKing)
def _variable_with_weight_decay(name, shape, stddev, wd):
    """Helper to create an initialized Variable with weight decay.

    Note that the Variable is initialized with a truncated normal distribution.
    A weight decay is added only if one is specified.

    Args:
      name: name of the variable
      shape: list of ints
      stddev: standard deviation of a truncated Gaussian
      wd: add L2Loss weight decay multiplied by this float. If None, weight
          decay is not added for this Variable.

    Returns:
      Variable Tensor
    """
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    var = _variable_on_cpu(
        name,
        shape,
        tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
    if wd is not None:
        weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
        tf.add_to_collection('losses', weight_decay)
    return var
Source: cifar10.py (project: MachineLearningTutorial, author: SpikeKing)
def distorted_inputs():
    """Construct distorted input for CIFAR training using the Reader ops.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If no data_dir
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                    batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels
Source: cifar10.py (project: MachineLearningTutorial, author: SpikeKing)
def inputs(eval_data):
    """Construct input for CIFAR evaluation using the Reader ops.

    Args:
      eval_data: bool, indicating if one should use the train or eval data set.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If no data_dir
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = cifar10_input.inputs(eval_data=eval_data,
                                          data_dir=data_dir,
                                          batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels
Source: tensorflow_backend.py (project: deep-learning-keras-projects, author: jasmeetsb)
def _convert_string_dtype(dtype):
    if dtype == 'float16':
        return tf.float16
    if dtype == 'float32':
        return tf.float32
    elif dtype == 'float64':
        return tf.float64
    elif dtype == 'int16':
        return tf.int16
    elif dtype == 'int32':
        return tf.int32
    elif dtype == 'int64':
        return tf.int64
    elif dtype == 'uint8':
        return tf.uint8
    elif dtype == 'uint16':
        return tf.uint16
    else:
        raise ValueError('Unsupported dtype:', dtype)
Source: translate.py (project: DeepLearning, author: Wanwannodao)
def create_model(sess, forward_only):
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    model = seq2seq_model.Seq2SeqModel(
        FLAGS.form_vocab_size,
        FLAGS.to_vocab_size,
        _buckets,
        FLAGS.size,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,
        FLAGS.batch_size,
        FLAGS.learning_rate,
        FLAGS.learning_rate_decay_factor,
        forward_only=forward_only,
        dtype=dtype)
    ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
    if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
        print("Reading model params from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh params")
        sess.run(tf.global_variables_initializer())
    return model
Source: cifar10.py (project: keras_experiments, author: avolkov1)
def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  var = _variable_on_cpu(
      name,
      shape,
      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
  if wd is not None:
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var
Source: cifar10.py (project: keras_experiments, author: avolkov1)
def distorted_inputs():
  """Construct distorted input for CIFAR training using the Reader ops.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not FLAGS.data_dir:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                  batch_size=FLAGS.batch_size)
  if FLAGS.use_fp16:
    images = tf.cast(images, tf.float16)
    labels = tf.cast(labels, tf.float16)
  return images, labels
Source: cifar10.py (project: keras_experiments, author: avolkov1)
def inputs(eval_data):
  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not FLAGS.data_dir:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  images, labels = cifar10_input.inputs(eval_data=eval_data,
                                        data_dir=data_dir,
                                        batch_size=FLAGS.batch_size)
  if FLAGS.use_fp16:
    images = tf.cast(images, tf.float16)
    labels = tf.cast(labels, tf.float16)
  return images, labels
Source: cifar10.py (project: visual-interaction-networks_tensorflow, author: jaesik817)
def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  var = _variable_on_cpu(
      name,
      shape,
      tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
  if wd is not None:
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var
Source: cifar10.py (project: visual-interaction-networks_tensorflow, author: jaesik817)
def distorted_inputs():
  """Construct distorted input for CIFAR training using the Reader ops.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not FLAGS.data_dir:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
                                                  batch_size=FLAGS.batch_size)
  if FLAGS.use_fp16:
    images = tf.cast(images, tf.float16)
    labels = tf.cast(labels, tf.float16)
  return images, labels

