def main(_):
    """Train the sentiment model, checkpointing on dev-set improvements.

    Loads the word embeddings and the train/dev trees named by FLAGS, builds
    the model plus an Adagrad training op in which the embedding gradient is
    multiplied by FLAGS.embedding_learning_rate_factor, and then iterates
    over td.epochs(train_set, FLAGS.epochs).  After every epoch the dev
    metrics are printed, and a checkpoint is saved whenever the dev
    'root_hits' metric is strictly greater than the best value seen so far.

    Args:
      _: Ignored (presumably the argv list handed over by the app runner;
        the caller is not visible here).

    Raises:
      RuntimeError: If the embedding weights do not appear exactly once
        among the variables returned by the optimizer.
    """
    print('loading word embeddings from %s' % FLAGS.embedding_file)
    weight_matrix, word_idx = sentiment.load_embeddings(FLAGS.embedding_file)

    train_file = os.path.join(FLAGS.tree_dir, 'train.txt')
    print('loading training trees from %s' % train_file)
    train_trees = sentiment.load_trees(train_file)

    dev_file = os.path.join(FLAGS.tree_dir, 'dev.txt')
    print('loading dev trees from %s' % dev_file)
    dev_trees = sentiment.load_trees(dev_file)

    with tf.Session() as sess:
        print('creating the model')
        # Falls back to 1.0 whenever it is not fed; only the training feed
        # dict below supplies FLAGS.keep_prob.
        keep_prob = tf.placeholder_with_default(1.0, [])
        train_feed_dict = {keep_prob: FLAGS.keep_prob}
        word_embedding = sentiment.create_embedding(weight_matrix)
        compiler, metrics = sentiment.create_model(
            word_embedding, word_idx, FLAGS.lstm_num_units, keep_prob)
        loss = tf.reduce_sum(compiler.metric_tensors['all_loss'])
        opt = tf.train.AdagradOptimizer(FLAGS.learning_rate)

        # Rescale only the embedding gradient; every other (grad, var) pair
        # is passed through untouched.
        embedding_var = word_embedding.weights
        grads_and_vars = []
        found = 0
        for grad, var in opt.compute_gradients(loss):
            # Identity comparison: we want this exact variable object, and
            # `is` cannot be affected by any overloading of `==`.
            if var is embedding_var:
                found += 1
                grad = tf.scalar_mul(
                    FLAGS.embedding_learning_rate_factor, grad)
            grads_and_vars.append((grad, var))
        # Internal consistency check.  Raised explicitly rather than
        # asserted so that it still runs under `python -O`.
        if found != 1:
            raise RuntimeError(
                'expected the embedding weights exactly once among the '
                'optimizer variables, found %d' % found)
        train = opt.apply_gradients(grads_and_vars)
        saver = tf.train.Saver()

        print('initializing tensorflow')
        sess.run(tf.global_variables_initializer())

        with compiler.multiprocessing_pool():
            print('training the model')
            train_set = compiler.build_loom_inputs(train_trees)
            dev_feed_dict = compiler.build_feed_dict(dev_trees)
            dev_hits_best = 0.0
            epochs = td.epochs(train_set, FLAGS.epochs)
            for epoch, shuffled in enumerate(epochs, 1):
                # Sum of the per-batch losses over this epoch.
                train_loss = 0.0
                for batch in td.group_by_batches(shuffled, FLAGS.batch_size):
                    train_feed_dict[compiler.loom_input_tensor] = batch
                    _, batch_loss = sess.run([train, loss], train_feed_dict)
                    train_loss += batch_loss

                # Evaluate on the dev set once per epoch.
                dev_metrics = sess.run(metrics, dev_feed_dict)
                dev_loss = dev_metrics['all_loss']
                # Report every metric whose name ends in 'hits', scaled
                # by 100.
                dev_accuracy = ['%s: %.2f' % (k, v * 100) for k, v in
                                sorted(dev_metrics.items())
                                if k.endswith('hits')]
                print('epoch:%4d, train_loss: %.3e, dev_loss: %.3e, '
                      'dev_accuracy: [%s]'
                      % (epoch, train_loss, dev_loss, ' '.join(dev_accuracy)))

                # Checkpoint only on a strict improvement of 'root_hits'.
                dev_hits = dev_metrics['root_hits']
                if dev_hits > dev_hits_best:
                    dev_hits_best = dev_hits
                    save_path = saver.save(
                        sess, FLAGS.checkpoint_base, global_step=epoch)
                    print('model saved in file: %s' % save_path)
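# NOTE: scraped web-page residue ("comment list" / "article table of
# contents" navigation labels); not part of the program.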