tf_regression.py source code

python

Project: tensorflow_to_lambda_serverless  Author: jacopotagliabue
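import os

import tensorflow as tf  # TensorFlow 1.x graph API (tf.compat.v1 in TF 2.x)


# Method of the regression model class; self.model (prediction tensor) and
# self.vars (dict holding the 'X' and 'Y' placeholders) are defined elsewhere in the class.
def train(self, train_X, train_Y, learning_rate, training_epochs, model_output_dir=None):
    n_samples = train_X.shape[0]
    # Mean squared error, halved: sum((prediction - Y)^2) / (2 * n_samples)
    cost = tf.reduce_sum(tf.pow(self.model - self.vars['Y'], 2)) / (2 * n_samples)
    # Gradient descent on that cost
    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
    # Launch the graph
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        # Fit all training data: one update per (x, y) sample, repeated every epoch
        for _ in range(training_epochs):
            for x, y in zip(train_X, train_Y):
                sess.run(optimizer, feed_dict={self.vars['X']: x, self.vars['Y']: y})
        # Save the model locally; use the current directory when model_output_dir is None
        saver.save(sess, os.path.join(model_output_dir or '.', 'model.ckpt'))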
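To make the update rule that `GradientDescentOptimizer` applies here concrete, the following is a minimal pure-Python sketch of the same per-sample gradient descent on the halved mean squared error. It assumes `self.model` is the 1-D linear prediction `W * X + b`; the class that defines `self.model` and `self.vars` is not shown, so that model, the zero starting values, and the names `train_linear_sgd` and `mean_squared_error` are hypothetical illustrations, not the project's code.

```python
# Hypothetical pure-Python stand-in for the TensorFlow graph above.
# Assumption: the model is pred = W * x + b (the original class is not shown).


def train_linear_sgd(train_X, train_Y, learning_rate, training_epochs):
    """Per-sample gradient descent on cost = (W*x + b - y)^2 / (2 * n_samples)."""
    n_samples = len(train_X)
    W, b = 0.0, 0.0  # starting point chosen for this sketch only
    for _ in range(training_epochs):
        for x, y in zip(train_X, train_Y):
            error = W * x + b - y
            # d(cost)/dW = error * x / n_samples and d(cost)/db = error / n_samples
            W -= learning_rate * error * x / n_samples
            b -= learning_rate * error / n_samples
    return W, b


def mean_squared_error(W, b, xs, ys):
    """Halved mean squared error over the whole data set, matching `cost` above."""
    return sum((W * x + b - y) ** 2 for x, y in zip(xs, ys)) / (2 * len(xs))


if __name__ == "__main__":
    xs = [1.0, 2.0, 3.0, 4.0]
    ys = [3.0, 5.0, 7.0, 9.0]  # points on y = 2x + 1
    W, b = train_linear_sgd(xs, ys, learning_rate=0.1, training_epochs=2000)
    print(W, b, mean_squared_error(W, b, xs, ys))
```

Because each `sess.run(optimizer, ...)` call feeds a single sample while the cost is still divided by `2 * n_samples`, every step is effectively taken with `learning_rate / n_samples`; the sketch keeps that scaling so its updates mirror the TensorFlow version.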