def fit(self, X_train, y_train, ridge=1.0):
    """Fit on training data by running the graph's distance/kernel ops.

    Inside a TensorFlow session on ``self.graph`` this method:
      1. Builds an (n, n) float32 matrix ``X_dists`` one row at a time with
         ``self.ops['dist_op']``.
      2. Computes ``self.K`` from ``X_dists`` and ``ridge`` with
         ``self.ops['K_ridge_op']``.
      3. Computes ``self.K_inv`` from ``self.K`` with ``self.ops['K_inv_op']``.
      4. Computes ``self.xy_`` from ``self.K_inv`` and ``self.y_train`` with
         ``self.ops['xy_op']`` (presumably K^-1 y -- confirm against the graph).

    Args:
        X_train: Training inputs. Validated by ``self.check_X_y`` and stored
            as float32 in ``self.X_train``.
        y_train: Training targets. Validated by ``self.check_X_y`` and stored
            as float32 in ``self.y_train``.
        ridge: Either a scalar (including a 0-d array), applied to every
            sample, or a 1-D array-like with one value per training sample.

    Returns:
        self, so calls can be chained.

    Raises:
        ValueError: If ``ridge`` is neither scalar nor a 1-D array-like of
            length ``X_train.shape[0]``.
    """
    self._reset()
    X_train, y_train = self.check_X_y(X_train, y_train)
    self.X_train = np.float32(X_train)
    self.y_train = np.float32(y_train)
    sample_size = self.X_train.shape[0]

    # Normalize ridge to a length-n vector. The previous ``assert ridge.ndim``
    # crashed on plain lists, vanished under ``python -O``, and never checked
    # the length, letting a mismatched vector reach the TF op.
    if np.ndim(ridge) == 0:
        ridge = np.full(sample_size, ridge, dtype=float)
    else:
        ridge = np.asarray(ridge, dtype=float)
    if ridge.shape != (sample_size,):
        raise ValueError(
            "ridge must be a scalar or a 1-D array of length %d, got shape %s"
            % (sample_size, ridge.shape))

    X_dists = np.zeros((sample_size, sample_size), dtype=np.float32)
    config = tf.ConfigProto(intra_op_parallelism_threads=self.NUM_THREADS)
    with tf.Session(graph=self.graph, config=config) as sess:
        # Each run feeds one training row as v1 and the full training set as
        # v2; the result becomes row i of X_dists.
        dist_op = self.ops['dist_op']
        v1_ph, v2_ph = self.vars['v1_h'], self.vars['v2_h']
        for i, row in enumerate(self.X_train):
            X_dists[i] = sess.run(
                dist_op, feed_dict={v1_ph: row, v2_ph: self.X_train})

        # Kernel matrix from the distances plus the per-sample ridge values.
        K_ridge_op = self.ops['K_ridge_op']
        X_dists_ph = self.vars['X_dists_h']
        ridge_ph = self.vars['ridge_h']
        self.K = sess.run(
            K_ridge_op, feed_dict={X_dists_ph: X_dists, ridge_ph: ridge})

        K_inv_op = self.ops['K_inv_op']
        K_ph = self.vars['K_h']
        self.K_inv = sess.run(K_inv_op, feed_dict={K_ph: self.K})

        xy_op = self.ops['xy_op']
        K_inv_ph = self.vars['K_inv_h']
        yt_ph = self.vars['yt_h']
        self.xy_ = sess.run(
            xy_op, feed_dict={K_inv_ph: self.K_inv, yt_ph: self.y_train})
    return self
评论列表
文章目录