def fit(self, X_train, y_train, X_valid, y_valid, X_test, y_test, steps=400):
    """Train the model for ``steps`` iterations, then report test metrics.

    NOTE(review): the source this was recovered from had its indentation
    stripped. The nesting below is a reconstruction: progress logging,
    threshold update and train/validation metrics run every 10th step;
    stderr is captured per step; test metrics run once after the loop.
    Confirm against the upstream version.

    Args:
        X_train: training features, indexed by the keys of ``self.features``;
            ``X_train[key]`` is fed to the tensor ``self.features[key]``.
        y_train: training labels, fed to ``self.labels``.
        X_valid: validation features, used for the decision threshold and
            the "Validation" metrics.
        y_valid: validation labels.
        X_test: test features, evaluated once after training.
        y_test: test labels.
        steps: number of training iterations; each iteration runs
            ``self.train_op`` once on all of ``X_train``.

    Side effects:
        Initializes all global variables in the default TensorFlow session.
        Updates ``self.threshold`` and the first entry of the "threshold"
        graph collection. Prints progress, and any captured stderr output,
        to stdout.
    """
    # Raises ValueError if no default session is registered.
    tf.global_variables_initializer().run()
    # The variables were just initialized in the default session, so the
    # training ops must run in that same session. Do not rely on a free
    # global ``sess`` here.
    session = tf.get_default_session()

    # The training data never changes between steps, so build the feed once.
    feed_dict = {self.labels: y_train}
    for key, tensor in self.features.items():
        feed_dict[tensor] = X_train[key]

    # NOTE(review): FDRedirector / STDERR are defined outside this view.
    # Presumably start() begins capturing stderr and stop() restores it and
    # returns the captured text -- confirm.
    redirect = FDRedirector(STDERR)
    for step in range(steps):
        redirect.start()
        try:
            # NOTE(review): assumes ``self.train_op`` evaluates to the loss
            # value. A bare optimizer op would return None here, and the
            # "{:.3g}" format below would raise -- TODO confirm.
            predictions, loss = session.run(
                [self.prediction, self.train_op], feed_dict=feed_dict)
            if step % 10 == 0:
                print("step:{} loss:{:.3g} np.std(predictions):{:.3g}".format(
                    step, loss, np.std(predictions)))
                # Take the lower of the thresholds derived from the
                # validation and the training data.
                self.threshold = float(min(
                    self.threshold_from_data(X_valid, y_valid),
                    self.threshold_from_data(X_train, y_train)))
                # Overwrite the first entry of the "threshold" collection
                # with the new value (presumably read elsewhere -- verify).
                tf.get_collection_ref("threshold")[0] = self.threshold
                self.print_metrics(X_train, y_train, "Training")
                self.print_metrics(X_valid, y_valid, "Validation")
        finally:
            # Always stop the redirect, even if the step raised. Otherwise
            # stderr stays redirected and the captured error text (often the
            # most useful diagnostic) is lost.
            errors = redirect.stop()
            if errors:
                print(errors)
    self.print_metrics(X_test, y_test, "Test")
# Scraped blog-page navigation residue (not code): "Comment list",
# "Table of contents".