def train_model(model, batch_gen, num_train_steps, weights_fld):
    """Train ``model`` for ``num_train_steps`` steps, checkpointing as it goes.

    If a checkpoint is found under ``checkpoints/`` the variables are restored
    from it and step numbering continues from ``model.global_step``; otherwise
    training starts from freshly initialised variables.  Summaries are written
    to ``improved_graph/lr<LEARNING_RATE>`` on every step, and every
    ``SKIP_STEP`` steps the average loss is printed and a checkpoint is saved
    to ``checkpoints/skip-gram``.

    Args:
        model: object exposing ``center_words``, ``target_words``, ``loss``,
            ``optimizer``, ``summary_op`` and ``global_step``.
        batch_gen: iterator yielding ``(centers, targets)`` batches; it must
            be able to supply ``num_train_steps`` batches.
        num_train_steps: number of training steps to run in this call.
        weights_fld: not used by this function; kept so existing callers
            keep working.
    """
    # Defaults to saving all variables - in this case embed_matrix,
    # nce_weight, nce_bias.
    saver = tf.train.Saver()

    utils.make_dir('checkpoints')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/checkpoint'))
        # If a checkpoint exists, restore from it (this overrides the fresh
        # initialisation above).
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Running sum used to report the average loss over the last
        # SKIP_STEP steps.
        total_loss = 0.0
        writer = tf.summary.FileWriter('improved_graph/lr' + str(LEARNING_RATE), sess.graph)
        try:
            # Read after the optional restore so a resumed run continues its
            # step numbering instead of starting again from zero.
            initial_step = model.global_step.eval()
            for index in range(initial_step, initial_step + num_train_steps):
                centers, targets = next(batch_gen)
                feed_dict = {model.center_words: centers, model.target_words: targets}
                loss_batch, _, summary = sess.run(
                    [model.loss, model.optimizer, model.summary_op],
                    feed_dict=feed_dict)
                writer.add_summary(summary, global_step=index)
                total_loss += loss_batch
                if (index + 1) % SKIP_STEP == 0:
                    print('Average loss at step {}: {:5.1f}'.format(index, total_loss / SKIP_STEP))
                    total_loss = 0.0
                    saver.save(sess, 'checkpoints/skip-gram', index)
        finally:
            # Flush pending events and release the event file even when a
            # step raises (e.g. the batch generator is exhausted).
            writer.close()

        ####################
        # Code to visualize the embeddings. Uncomment the lines below to
        # visualize embeddings, then run `tensorboard --logdir=processed`
        # to see them.
        # final_embed_matrix = sess.run(model.embed_matrix)
        # # It has to be a variable; constants don't work here, and you
        # # can't reuse model.embed_matrix.
        # embedding_var = tf.Variable(final_embed_matrix[:1000], name='embedding')
        # sess.run(embedding_var.initializer)
        # config = projector.ProjectorConfig()
        # summary_writer = tf.summary.FileWriter('processed')
        # # Add the embedding to the config file.
        # embedding = config.embeddings.add()
        # embedding.tensor_name = embedding_var.name
        # # Link this tensor to its metadata file, in this case the first
        # # 1000 words of the vocabulary.
        # embedding.metadata_path = 'processed/vocab_1000.tsv'
        # # Saves a configuration file that TensorBoard will read during startup.
        # projector.visualize_embeddings(summary_writer, config)
        # saver_embed = tf.train.Saver([embedding_var])
        # saver_embed.save(sess, 'processed/model3.ckpt', 1)
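# Page-scrape residue (site navigation labels, not part of the program):
#   "评论列表" (comment list)
#   "文章目录" (table of contents)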