def train(self, t_x, t_y, v_x, v_y, lrv, char2idx, sess, epochs, batch_size=10,
          predict_batch_size=100):
    """Train the seq2seq model, checkpointing the best validation token F score.

    Each epoch runs one training pass over (t_x, t_y) via Batch.train_seq2seq,
    decodes the validation inputs v_x via Batch.predict_seq2seq (argmax
    decoding), scores the decoded strings against the gold strings built from
    v_y with evaluation.trans_evaluator, and prints the epoch number, ACC
    (c_scores[0]) and token F score (c_scores[1]).  Whenever the token F score
    strictly improves on the best seen so far, the weights are saved to
    ``self.trained + '_weights'`` (without the meta graph).  After the last
    epoch, the best saved weights are restored into ``sess`` if any were saved.

    Args:
        t_x, t_y: training inputs and targets, passed to Batch.train_seq2seq
            as ``data=[t_x, t_y]``.
        v_x: validation inputs, passed to Batch.predict_seq2seq.
        v_y: validation target index sequences. Zeros are stripped from both
            ends with np.trim_zeros (assumes 0 is the padding index -- TODO
            confirm).
        lrv: forwarded unchanged to Batch.train_seq2seq as ``lrv``
            (presumably a learning-rate value -- confirm against Batch).
        char2idx: presumably a character -> index mapping; it is inverted
            here to turn index sequences back into output strings.
        sess: session used for training, prediction, saving and restoring.
        epochs: number of training epochs.
        batch_size: batch size for the training pass.
        predict_batch_size: batch size for validation decoding (default 100,
            the value previously hard-coded).
    """
    idx2char = {idx: char for char, idx in char2idx.items()}

    # The gold validation strings do not change across epochs: build them once.
    gold_out = [toolbox.generate_trans_out(np.trim_zeros(v_y_t), idx2char)
                for v_y_t in v_y]

    best_score = 0
    for epoch in range(epochs):
        Batch.train_seq2seq(sess, model=self.en_vec + self.trans_labels,
                            decoding=self.feed_previous, batch_size=batch_size,
                            config=self.trans_train, lr=self.trans_l_rate,
                            lrv=lrv, data=[t_x, t_y])
        pred = Batch.predict_seq2seq(
            sess, model=self.en_vec + self.de_vec + self.trans_output,
            decoding=self.feed_previous, decode_len=self.decode_step,
            data=[v_x], argmax=True, batch_size=predict_batch_size)
        pred_out = [toolbox.generate_trans_out(pre_t, idx2char) for pre_t in pred]
        c_scores = evaluation.trans_evaluator(gold_out, pred_out)

        # Single-argument print(...) prints the same text under Python 2 and 3.
        print('epoch: %d' % (epoch + 1))
        print('ACC: %f' % c_scores[0])
        print('Token F score: %f' % c_scores[1])

        # Checkpoint only on strict improvement of the token F score.
        if c_scores[1] > best_score:
            best_score = c_scores[1]
            self.saver.save(sess, self.trained + '_weights', write_meta_graph=False)

    # A checkpoint exists only if some epoch scored above 0; otherwise keep the
    # weights from the final epoch.
    if best_score > 0:
        self.saver.restore(sess, self.trained + '_weights')
评论列表
文章目录