learning_test.py source code

python

Project: lsdc, author: febert
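import os

import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

slim = tf.contrib.slim


# Method of the tf.test.TestCase subclass in learning_test.py. self._inputs,
# self._labels and LogisticClassifier are defined elsewhere in that file.
def testTrainWithNonDefaultGraph(self):
  self._logdir = os.path.join(self.get_temp_dir(), 'tmp_logs8/')
  # Build every op in a fresh graph g instead of the global default graph.
  g = tf.Graph()
  with g.as_default():
    tf.set_random_seed(0)
    tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
    tf_labels = tf.constant(self._labels, dtype=tf.float32)

    tf_predictions = LogisticClassifier(tf_inputs)
    # Adds the log loss to g's losses collection; get_total_loss() sums it.
    slim.losses.log_loss(tf_predictions, tf_labels)
    total_loss = slim.losses.get_total_loss()

    optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)

    train_op = slim.learning.create_train_op(total_loss, optimizer)

  # train() is called outside the `with` block, so graph=g tells it to run
  # train_op in g rather than in the default graph.
  loss = slim.learning.train(
      train_op, self._logdir, number_of_steps=300, log_every_n_steps=10,
      graph=g)
  self.assertIsNotNone(loss)
  self.assertLess(loss, .015)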
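The final assertion (loss below 0.015 after 300 steps at learning rate 1.0) depends on what the graph trains: a logistic classifier fitted by gradient descent on log loss. The following is a minimal stdlib-only sketch of that kind of training loop outside TensorFlow, assuming LogisticClassifier is a single sigmoid unit (its definition is not shown here). The toy dataset, the zero initialization, and the helper names `sigmoid`, `predict`, `log_loss` and `train_logistic` are hypothetical, so the numbers it prints are not the test's values.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    """Logistic function 1 / (1 + e^-z)."""
    return 1.0 / (1.0 + math.exp(-z))


def predict(w: Sequence[float], b: float,
            inputs: Sequence[Sequence[float]]) -> List[float]:
    """Single sigmoid unit: p = sigmoid(w . x + b) for each example."""
    return [sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) for x in inputs]


def log_loss(preds: Sequence[float], labels: Sequence[float],
             eps: float = 1e-7) -> float:
    """Mean binary cross-entropy; eps keeps log() away from 0."""
    total = 0.0
    for p, y in zip(preds, labels):
        total += -y * math.log(p + eps) - (1.0 - y) * math.log(1.0 - p + eps)
    return total / len(preds)


def train_logistic(inputs: Sequence[Sequence[float]], labels: Sequence[float],
                   learning_rate: float = 1.0, steps: int = 300) -> float:
    """Full-batch gradient descent; returns the log loss after the last update."""
    n, d = len(inputs), len(inputs[0])
    w, b = [0.0] * d, 0.0  # zero init (hypothetical; slim uses its own initializer)
    for _ in range(steps):
        preds = predict(w, b, inputs)
        # d(mean log loss)/d(logit) = (p - y) / n, ignoring the tiny eps term.
        errs = [(p - y) / n for p, y in zip(preds, labels)]
        grad_w = [sum(e * x[j] for e, x in zip(errs, inputs)) for j in range(d)]
        grad_b = sum(errs)
        w = [wi - learning_rate * gw for wi, gw in zip(w, grad_w)]
        b -= learning_rate * grad_b
    return log_loss(predict(w, b, inputs), labels)


# Hypothetical toy data: one-hot inputs; features 0-1 mean label 0,
# features 2-3 mean label 1.
TOY_INPUTS = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
TOY_LABELS = [0.0, 0.0, 1.0, 1.0]

if __name__ == "__main__":
    print("initial loss:", train_logistic(TOY_INPUTS, TOY_LABELS, steps=0))
    print("loss after 300 steps:", train_logistic(TOY_INPUTS, TOY_LABELS))
```

With all weights at zero every prediction is 0.5, so the loss starts at ln 2 ≈ 0.693 and falls toward zero as gradient descent separates the two classes. A full-batch loop is used here because the test feeds the whole constant dataset on every step.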