test_dbinterface.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:tfutils 作者: neuroailab 项目源码 文件源码
def train_model(self, num_steps=100):
        x_train = [1, 2, 3, 4]
        y_train = [0, -1, -2, -3]
        x = tf.get_default_graph().get_tensor_by_name('model_0/x:0')
        y = tf.get_default_graph().get_tensor_by_name('model_0/y:0')
        feed_dict = {x: x_train, y: y_train}

        pre_global_step = self.sess.run(self.global_step)
        for step in range(num_steps):
            train_res = self.sess.run(self.train_targets, feed_dict=feed_dict)
            self.log.info('Step: {}, loss: {}'.format(step, train_res['loss']))

        post_global_step = self.sess.run(self.global_step)
        self.assertEqual(pre_global_step + num_steps, post_global_step)
        self.step += num_steps
        return train_res
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号