def train(self):
    # Vectorize the raw dialog data into (story, query, answer) arrays.
    trainS, trainQ, trainA = vectorize_data(
        self.trainData, self.word_idx, self.sentence_size, self.batch_size, self.n_cand, self.memory_size)
    valS, valQ, valA = vectorize_data(
        self.valData, self.word_idx, self.sentence_size, self.batch_size, self.n_cand, self.memory_size)
    n_train = len(trainS)
    n_val = len(valS)
    print("Training Size", n_train)
    print("Validation Size", n_val)
    tf.set_random_seed(self.random_state)
    # Precompute (start, end) index pairs, one per mini-batch.
    batches = zip(range(0, n_train - self.batch_size, self.batch_size),
                  range(self.batch_size, n_train, self.batch_size))
    batches = [(start, end) for start, end in batches]
    best_validation_accuracy = 0
    for t in range(1, self.epochs + 1):
        np.random.shuffle(batches)
        total_cost = 0.0
        for start, end in batches:
            s = trainS[start:end]
            q = trainQ[start:end]
            a = trainA[start:end]
            cost_t = self.model.batch_fit(s, q, a)
            total_cost += cost_t
        if t % self.evaluation_interval == 0:
            # Evaluate on the full training and validation sets.
            train_preds = self.batch_predict(trainS, trainQ, n_train)
            val_preds = self.batch_predict(valS, valQ, n_val)
            train_acc = metrics.accuracy_score(train_preds, trainA)
            val_acc = metrics.accuracy_score(val_preds, valA)
            print('-----------------------')
            print('Epoch', t)
            print('Total Cost:', total_cost)
            print('Training Accuracy:', train_acc)
            print('Validation Accuracy:', val_acc)
            print('-----------------------')
            # Write accuracy summaries for TensorBoard. Note that creating the
            # summary ops here adds new nodes to the graph on every evaluation;
            # see the placeholder-based variant below the listing.
            train_acc_summary = tf.summary.scalar(
                'task_' + str(self.task_id) + '/' + 'train_acc',
                tf.constant(train_acc, dtype=tf.float32))
            val_acc_summary = tf.summary.scalar(
                'task_' + str(self.task_id) + '/' + 'val_acc',
                tf.constant(val_acc, dtype=tf.float32))
            merged_summary = tf.summary.merge(
                [train_acc_summary, val_acc_summary])
            summary_str = self.sess.run(merged_summary)
            self.summary_writer.add_summary(summary_str, t)
            self.summary_writer.flush()
            # Checkpoint the model whenever validation accuracy improves.
            if val_acc > best_validation_accuracy:
                best_validation_accuracy = val_acc
                self.saver.save(self.sess, self.model_dir + 'model.ckpt',
                                global_step=t)
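One caveat about the summary code above: in TensorFlow 1.x, every call to tf.summary.scalar and tf.constant adds new nodes to the graph, so building the summaries inside the evaluation branch makes the graph grow on each evaluation. A common alternative is to build the summary ops once with placeholders and feed the freshly computed Python accuracy values at run time. The snippet below is a minimal sketch of that pattern, assuming the same TensorFlow 1.x session and summary writer; build_acc_summaries, train_acc_ph, and val_acc_ph are illustrative names, not part of the original code.

import tensorflow as tf

def build_acc_summaries(task_id):
    """Build the accuracy summary ops once; return the merged op and its placeholders."""
    train_acc_ph = tf.placeholder(tf.float32, shape=(), name='train_acc_ph')
    val_acc_ph = tf.placeholder(tf.float32, shape=(), name='val_acc_ph')
    train_summary = tf.summary.scalar('task_' + str(task_id) + '/train_acc', train_acc_ph)
    val_summary = tf.summary.scalar('task_' + str(task_id) + '/val_acc', val_acc_ph)
    summary_op = tf.summary.merge([train_summary, val_summary])
    return summary_op, train_acc_ph, val_acc_ph

# In the evaluation branch of train(), after computing train_acc and val_acc:
#     summary_str = self.sess.run(
#         summary_op, feed_dict={train_acc_ph: train_acc, val_acc_ph: val_acc})
#     self.summary_writer.add_summary(summary_str, t)

With this pattern the graph stays fixed after construction, and each evaluation only runs the existing ops with new scalar values.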