def fit(self, train_data, train_labels, val_data, val_labels):
    """Train the model, evaluating it periodically on the validation set.

    The 'summaries' and 'checkpoints' directories (as resolved by
    ``self._get_path``) are wiped and recreated, ``self.op_init`` is run,
    and ``int(num_epochs * num_samples / batch_size)`` training steps are
    performed on mini-batches drawn without replacement. Every
    ``self.eval_frequency`` steps, and on the last step, the model is
    evaluated with ``self.evaluate`` on both the training and validation
    sets, a TensorBoard summary is written and a checkpoint is saved.

    Args:
        train_data: Training samples. Must expose ``shape[0]`` and support
            ``train_data[idx, :, :, :]`` indexing with a list of row indices.
        train_labels: Training labels, indexed as ``train_labels[idx]``.
        val_data: Validation samples, passed as-is to ``self.evaluate``.
        val_labels: Validation labels, passed as-is to ``self.evaluate``.

    Returns:
        A tuple ``(accuracies, losses, t_step, scores_summary)`` where
        ``accuracies`` and ``losses`` are lists holding the ``auc`` and
        ``loss`` values returned by ``self.evaluate`` on the validation set
        at each evaluation, ``t_step`` is the wall-clock time of the whole
        call divided by the number of steps, and ``scores_summary`` is the
        fourth value returned by the last validation evaluation.

    Raises:
        ValueError: If the configuration yields no training step. This is
            raised before any directory is removed.
    """
    t_process, t_wall = time.process_time(), time.time()

    num_samples = train_data.shape[0]
    num_steps = int(self.num_epochs * num_samples / self.batch_size)
    if num_steps < 1:
        # Fail early and explicitly: without a single step there is nothing
        # to report, and existing summaries/checkpoints must not be wiped.
        raise ValueError(
            'no training step to run: int(num_epochs * num_samples / '
            'batch_size) = int({} * {} / {}) < 1'.format(
                self.num_epochs, num_samples, self.batch_size))

    sess = tf.Session(graph=self.graph)
    writer = None
    try:
        shutil.rmtree(self._get_path('summaries'), ignore_errors=True)
        writer = tf.summary.FileWriter(self._get_path('summaries'), self.graph)
        shutil.rmtree(self._get_path('checkpoints'), ignore_errors=True)
        os.makedirs(self._get_path('checkpoints'))
        path = os.path.join(self._get_path('checkpoints'), 'model')
        sess.run(self.op_init)

        # Training.
        accuracies = []
        losses = []
        indices = collections.deque()
        for step in range(1, num_steps + 1):

            # Be sure to have used all the samples before using one a second
            # time. Loop (rather than refill once) so that a dataset smaller
            # than the batch size still yields a full batch.
            while len(indices) < self.batch_size:
                indices.extend(np.random.permutation(num_samples))
            idx = [indices.popleft() for _ in range(self.batch_size)]

            batch_data, batch_labels = train_data[idx, :, :, :], train_labels[idx]
            if not isinstance(batch_data, np.ndarray):
                batch_data = batch_data.toarray()  # convert sparse matrices
            feed_dict = {
                self.ph_data: batch_data,
                self.ph_labels: batch_labels,
                self.ph_dropout: self.dropout,
            }
            learning_rate, loss_average = sess.run(
                [self.op_train, self.op_loss_average], feed_dict)

            # Periodical evaluation of the model. The last step is always
            # evaluated, so the returned lists are never empty.
            if step % self.eval_frequency == 0 or step == num_steps:
                epoch = step * self.batch_size / num_samples
                print('step {} / {} (epoch {:.2f} / {}):'.format(step, num_steps, epoch, self.num_epochs))
                print(' learning_rate = {:.2e}, loss_average = {:.2e}'.format(learning_rate, loss_average))
                # Only the report string of the training evaluation is used.
                string, _, _, _ = self.evaluate(train_data, train_labels, sess)
                print(' training {}'.format(string))
                string, auc, loss, scores_summary = self.evaluate(val_data, val_labels, sess)
                print(' validation {}'.format(string))
                print(' time: {:.0f}s (wall {:.0f}s)'.format(time.process_time()-t_process, time.time()-t_wall))
                accuracies.append(auc)
                losses.append(loss)

                # Summaries for TensorBoard.
                summary = tf.Summary()
                summary.ParseFromString(sess.run(self.op_summary, feed_dict))
                summary.value.add(tag='validation/auc', simple_value=auc)
                summary.value.add(tag='validation/loss', simple_value=loss)
                writer.add_summary(summary, step)

                # Save model parameters (for evaluation).
                self.op_saver.save(sess, path, global_step=step)

        # NOTE: the reported values are the validation `auc` scores; the
        # message label is kept unchanged.
        print('validation accuracy: peak = {:.2f}, mean = {:.2f}'.format(max(accuracies), np.mean(accuracies[-10:])))
    finally:
        # Release the summary writer and the session even when training,
        # evaluation or checkpointing raises.
        try:
            if writer is not None:
                writer.close()
        finally:
            sess.close()

    t_step = (time.time() - t_wall) / num_steps
    return accuracies, losses, t_step, scores_summary
# Comment list
# Table of contents