def fit(self,
        features: np.ndarray,
        labels: np.ndarray,
        quiet: bool = False) -> None:
    """
    Train the model on the given data and write a timestamped checkpoint.

    A fresh TensorFlow graph is built on every call. The complete feature and
    label arrays are embedded in that graph as non-trainable variables, so
    each epoch is a single run of the training op on those tensors. After the
    last epoch, the global variables in the ``"mlp"`` scope are saved below
    ``self._checkpoint_dir`` (without a meta graph), and the checkpoint path
    is stored in ``self._latest_checkpoint``.

    Parameters
    ----------
    features: numpy.ndarray
        Training features; stored in the graph as a ``tf.float32`` variable.
    labels: numpy.ndarray
        Training labels; stored in the graph as a ``tf.int32`` variable. The
        number of distinct values is stored in ``self._num_labels`` and
        passed to the model's ``inference`` method.
    quiet: bool, default False
        Not used by this method.
        NOTE(review): presumably kept for compatibility with the base class
        signature - confirm.
    """
    # generic parameter checks
    super().fit(features, labels)

    self._num_labels = len(np.unique(labels))

    graph = tf.Graph()

    with graph.as_default():
        tf_inputs = tf.Variable(initial_value=features, trainable=False, dtype=tf.float32)
        tf_labels = tf.Variable(initial_value=labels, trainable=False, dtype=tf.int32)

        if self._shuffle_training:
            # NOTE(review): both shuffles use the same op-level seed, presumably so
            # that features and labels receive the same permutation - confirm.
            tf_inputs = tf.random_shuffle(tf_inputs, seed=42)
            tf_labels = tf.random_shuffle(tf_labels, seed=42)

        with tf.variable_scope("mlp"):
            tf_logits = self._model.inference(tf_inputs, self._keep_prob, self._num_labels)
            tf_loss = self._model.loss(tf_logits, tf_labels)
            tf_train_op = self._model.optimize(tf_loss, self._learning_rate)

        tf_init_op = tf.global_variables_initializer()
        # only variables created under the "mlp" scope are checkpointed
        tf_saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="mlp"))

    # The context manager closes the session even if initialisation, training or
    # saving raises, so a failed fit does not leak the session.
    with tf.Session(graph=graph) as session:
        session.run(tf_init_op)

        for _ in range(self._num_epochs):
            session.run(tf_train_op)

        # timestamped model file
        checkpoint = self._checkpoint_dir / "model_{:%Y%m%d%H%M%S%f}".format(datetime.datetime.now())
        tf_saver.save(session, str(checkpoint), write_meta_graph=False)

    # Publish the path only once the checkpoint has actually been written.
    self._latest_checkpoint = checkpoint
评论列表
文章目录