def train_model(self, num_steps=100):
    """Run ``num_steps`` training steps on a fixed toy dataset.

    Looks up the tensors named ``model_0/x:0`` and ``model_0/y:0`` in the
    default graph and feeds them the hard-coded inputs ``[1, 2, 3, 4]`` and
    targets ``[0, -1, -2, -3]``. Each step runs ``self.train_targets``
    once through ``self.sess`` and logs that step's ``'loss'`` entry.

    After the loop, ``self.assertEqual`` checks that ``self.global_step``
    advanced by exactly ``num_steps``. ``self.step`` is then incremented
    by ``num_steps``.

    Args:
        num_steps: Number of training iterations to run (default 100).

    Returns:
        The result of the last ``self.sess.run(self.train_targets, ...)``
        call, or ``None`` when ``num_steps`` is 0 and no step was run.
    """
    x_train = [1, 2, 3, 4]
    y_train = [0, -1, -2, -3]
    # Look the graph up once; both tensors come from the same default graph.
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name('model_0/x:0')
    y = graph.get_tensor_by_name('model_0/y:0')
    feed_dict = {x: x_train, y: y_train}

    pre_global_step = self.sess.run(self.global_step)
    # Initialised up front so that num_steps == 0 returns None instead of
    # raising UnboundLocalError at the final return.
    train_res = None
    for step in range(num_steps):
        train_res = self.sess.run(self.train_targets, feed_dict=feed_dict)
        self.log.info('Step: {}, loss: {}'.format(step, train_res['loss']))
    post_global_step = self.sess.run(self.global_step)

    # The check requires every run of train_targets to advance
    # global_step by exactly one.
    self.assertEqual(pre_global_step + num_steps, post_global_step)
    self.step += num_steps
    return train_res
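# (scraped web-page residue, not code) 评论列表 - "Comment list"
# (scraped web-page residue, not code) 文章目录 - "Table of contents"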