parametric_GP.py source code

python

Project: ParametricGP-in-Python  Author: maziarraissi
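import timeit

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, session-based execution)


def train(self):
    print("Total number of parameters: %d" % (self.hyp.shape[0]))

    # Placeholders receive one minibatch per step; the hyperparameters are the trainable variable
    X_tf = tf.placeholder(tf.float64)
    y_tf = tf.placeholder(tf.float64)
    hyp_tf = tf.Variable(self.hyp, dtype=tf.float64)

    # self.likelihood returns the op that is run once per minibatch
    train = self.likelihood(hyp_tf, X_tf, y_tf)

    init = tf.global_variables_initializer()
    self.sess.run(init)

    start_time = timeit.default_timer()
    for i in range(1, self.max_iter + 1):
        # Fetch minibatch
        X_batch, y_batch = fetch_minibatch(self.X, self.y, self.N_batch)
        self.sess.run(train, {X_tf: X_batch, y_tf: y_batch})

        # Every monitor_likelihood iterations, report the NLML and the time since the last report
        if i % self.monitor_likelihood == 0:
            elapsed = timeit.default_timer() - start_time
            nlml = self.sess.run(self.nlml)
            print('Iteration: %d, NLML: %.2f, Time: %.2f' % (i, nlml, elapsed))
            start_time = timeit.default_timer()

    # Read the trained hyperparameters back out of the TensorFlow variable
    self.hyp = self.sess.run(hyp_tf)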
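
The `fetch_minibatch` helper called inside the loop is not shown in this excerpt. The following is a minimal sketch of what such a helper might look like, assuming `X` and `y` are NumPy arrays with one sample per row and that a batch is drawn uniformly at random without replacement. The optional `rng` argument is a hypothetical addition for reproducibility and is not part of the call shown above; the project's own helper may differ.

```python
import numpy as np


def fetch_minibatch(X, y, N_batch, rng=None):
    """Return N_batch randomly chosen rows of X with the matching rows of y."""
    # Assumption: a fresh generator is created when the caller does not pass one
    rng = np.random.default_rng() if rng is None else rng
    N = X.shape[0]
    # Sample row indices without replacement; cap the batch size at N
    idx = rng.choice(N, size=min(N_batch, N), replace=False)
    return X[idx], y[idx]


# Usage: 10 samples with 2 features each; y holds the row sums of X
rng = np.random.default_rng(0)
X = np.arange(20, dtype=np.float64).reshape(10, 2)
y = X.sum(axis=1, keepdims=True)
X_batch, y_batch = fetch_minibatch(X, y, 4, rng)
print(X_batch.shape, y_batch.shape)
```

Using one index array for both `X` and `y` keeps inputs and targets aligned, which is what the training step relies on when it feeds `X_batch` and `y_batch` to the placeholders.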