basic_test.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testGradientColocation(self):
    """Tests a particular device (e.g. gpu, cpu) placement.

    This test ensures that the following device placement is possible:

    * The Linear module is on the gpu,
    * the optimizer is declared to be on the cpu,
    * but when calling minimize on the optimizer, we pass True to
      colocate_gradients_with_ops.

    The test exists because while one may expect tf.matmul(X, w) + b to be
    equivalent to tf.nn.xw_plus_b(X, w, b), with the latter this placement
    results in an InvalidArgumentError.

    The test requires that TensorFlow has access to a GPU, which is often not
    the case. When no GPU device is listed by ``device_lib``, the test is
    skipped through ``self.skipTest`` so that test runners report it as
    skipped instead of silently counting it as a pass.
    """
    if not any(x.device_type == "GPU" for x in device_lib.list_local_devices()):
      # A bare `return` here would make the runner record a PASS on CPU-only
      # machines; skipTest records an explicit skip with the reason attached.
      self.skipTest("Skipping the gradient colocation test as there is no "
                    "gpu available to tensorflow.")
    n_outputs = 5
    n_inputs = 3
    batch_size = 7
    linear = snt.Linear(n_outputs)
    # Everything built in this scope is requested on the CPU, except the
    # Linear module's ops, which the nested scope below requests on the GPU.
    with tf.device("/cpu:*"):
      # Set up data.
      inputs = tf.placeholder(tf.float32, [batch_size, n_inputs])
      # Explicit 1-tuple: one int64 label per example in the batch.
      labels = tf.to_int64(np.ones((batch_size,)))
      # Predictions.
      with tf.device("/gpu:*"):
        outputs = linear(inputs)
      # Calculate the loss.
      cross_entropy = tf.contrib.nn.deprecated_flipped_sparse_softmax_cross_entropy_with_logits(  # pylint: disable=line-too-long
          outputs, labels, name="xentropy")
      loss = tf.reduce_mean(cross_entropy, name="xentropy_mean")
      # Optimizer. Declared under the CPU scope, but colocating gradients
      # asks TF to place each gradient op next to the forward op it
      # differentiates, which is the mixed placement under test.
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
      optimizer.minimize(loss, colocate_gradients_with_ops=True)
    init = tf.global_variables_initializer()
    try:
      # force_gpu=True requests GPU placement for the test session; an
      # unsatisfiable placement surfaces as InvalidArgumentError here.
      with self.test_session(force_gpu=True) as sess:
        sess.run(init)
    except tf.errors.InvalidArgumentError as e:
      self.fail("Cannot start the session. Details:\n" + e.message)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号