def testTrainWithSessionConfig(self):
    """Train a logistic classifier while supplying an explicit session config.

    Builds the model in a fresh graph, trains it for a fixed number of steps
    with no log directory, and checks that the value returned by
    ``slim.learning.train`` exists and is below the expected threshold.
    """
    with tf.Graph().as_default():
        # Fix the graph-level seed before any ops are created so runs repeat.
        tf.set_random_seed(0)

        inputs = tf.constant(self._inputs, dtype=tf.float32)
        labels = tf.constant(self._labels, dtype=tf.float32)
        predictions = LogisticClassifier(inputs)

        # log_loss registers its loss in the slim loss collection, which
        # get_total_loss then aggregates.
        slim.losses.log_loss(predictions, labels)
        objective = slim.losses.get_total_loss()

        train_op = slim.learning.create_train_op(
            objective,
            tf.train.GradientDescentOptimizer(learning_rate=1.0),
        )

        # The behavior under test: train() accepts a caller-provided config.
        config = tf.ConfigProto(allow_soft_placement=True)
        final_loss = slim.learning.train(
            train_op,
            logdir=None,
            number_of_steps=300,
            log_every_n_steps=10,
            session_config=config,
        )

        self.assertIsNotNone(final_loss)
        self.assertLess(final_loss, .015)
评论列表
文章目录