def train_model(model, batch_gen, num_train_steps, weights_fld):
    """Train the skip-gram model, checkpointing periodically, then export embeddings.

    Restores from the latest checkpoint in ``checkpoints/`` if one exists, runs
    ``num_train_steps`` optimizer steps (numbered from the restored
    ``model.global_step``), prints the average loss and saves a checkpoint every
    ``SKIP_STEP`` steps, and finally exports the first 1000 embedding rows for the
    TensorBoard embedding projector under ``processed/``.

    Args:
        model: model object exposing ``center_words``, ``target_words``, ``loss``,
            ``optimizer``, ``summary_op``, ``global_step`` and ``embed_matrix``.
        batch_gen: iterator yielding ``(centers, targets)`` pairs fed to
            ``model.center_words`` / ``model.target_words``.
        num_train_steps: number of training steps to run in this call.
        weights_fld: unused; kept so existing callers keep working.
    """
    # With no var_list the Saver covers every variable in the graph
    # (here: embed_matrix, nce_weight, nce_bias).
    saver = tf.train.Saver()
    utils.make_dir('checkpoints')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Resume from the latest checkpoint in checkpoints/, if one exists.
        ckpt = tf.train.get_checkpoint_state('checkpoints')
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Sum of batch losses in the current SKIP_STEP window, used for the printed average.
        total_loss = 0.0
        writer = tf.summary.FileWriter(
            'improved_graph/lr' + str(LEARNING_RATE), sess.graph)
        try:
            # Continue step numbering from wherever the restored run left off.
            initial_step = model.global_step.eval()
            for index in range(initial_step, initial_step + num_train_steps):
                centers, targets = next(batch_gen)
                feed_dict = {
                    model.center_words: centers,
                    model.target_words: targets,
                }
                loss_batch, _, summary = sess.run(
                    [model.loss, model.optimizer, model.summary_op],
                    feed_dict=feed_dict)
                writer.add_summary(summary, global_step=index)
                total_loss += loss_batch
                if (index + 1) % SKIP_STEP == 0:
                    print('Average loss at step {}: {:5.1f}'.format(
                        index, total_loss / SKIP_STEP))
                    total_loss = 0.0
                    saver.save(sess, 'checkpoints/skip-gram', index)
        finally:
            # Flush any buffered training summaries to disk and release the writer.
            writer.close()

        _export_embeddings_for_projector(sess, model)


def _export_embeddings_for_projector(sess, model):
    """Save the first 1000 embedding rows and a projector config under ``processed/``.

    Run ``tensorboard --logdir=processed`` to browse the exported embeddings.
    """
    final_embed_matrix = sess.run(model.embed_matrix)
    # The projector needs a variable (constants don't work), and model.embed_matrix
    # can't be reused directly, so copy the rows into a fresh variable.
    embedding_var = tf.Variable(final_embed_matrix[:1000], name='embedding')
    sess.run(embedding_var.initializer)

    config = projector.ProjectorConfig()
    # Creating the writer also creates the processed/ directory the saver below writes into.
    summary_writer = tf.summary.FileWriter('processed')
    try:
        embedding = config.embeddings.add()
        embedding.tensor_name = embedding_var.name
        # Row labels for the 1000 exported rows; presumably one vocabulary word per
        # line, produced elsewhere — TODO confirm the file exists before export.
        embedding.metadata_path = 'processed/vocab_1000.tsv'
        # Writes the projector configuration that TensorBoard reads at startup.
        projector.visualize_embeddings(summary_writer, config)
        saver_embed = tf.train.Saver([embedding_var])
        saver_embed.save(sess, 'processed/model3.ckpt', 1)
    finally:
        summary_writer.close()
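# --- Web-page scrape residue (translated and commented out so the file parses) ---
# word2vec_visualize.py source file
# python
# Reads: 25
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents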