02_word2vec_visualize.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:tf_oreilly 作者: chiphuyen 项目源码 文件源码
def train_model(model, batch_gen, num_train_steps, weights_fld):
    """Train the skip-gram model, resuming from the latest checkpoint if present.

    Runs ``num_train_steps`` optimizer steps starting at the model's current
    ``global_step``. Every step's summary is written for TensorBoard under
    ``improved_graph/lr<LEARNING_RATE>``. Every ``SKIP_STEP`` steps it does two
    things: prints the average loss over that window, and saves a checkpoint
    with prefix ``checkpoints/skip-gram`` tagged with the step index.

    Args:
        model: Object exposing ``center_words`` / ``target_words`` (fed with
            each batch), ``loss``, ``optimizer``, ``summary_op`` and
            ``global_step``.
        batch_gen: Iterator yielding ``(centers, targets)`` batches.
        num_train_steps: Number of steps to run in this call.
        weights_fld: Not used by this function; kept so existing callers
            keep working.
    """
    # With no var_list, the Saver covers all global variables. That is what
    # lets global_step survive a restart; it is read back below.
    saver = tf.train.Saver()

    utils.make_dir('checkpoints')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Restore the most recent checkpoint recorded in 'checkpoints/', if any.
        ckpt = tf.train.get_checkpoint_state('checkpoints')
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        total_loss = 0.0  # running loss sum, used to report the average over each SKIP_STEP window
        writer = tf.summary.FileWriter('improved_graph/lr' + str(LEARNING_RATE), sess.graph)
        try:
            # Continue from the (possibly restored) global step so that summary
            # steps, checkpoint names and loss reports line up across restarts.
            initial_step = model.global_step.eval()
            for index in range(initial_step, initial_step + num_train_steps):
                centers, targets = next(batch_gen)
                feed_dict = {model.center_words: centers, model.target_words: targets}
                loss_batch, _, summary = sess.run(
                    [model.loss, model.optimizer, model.summary_op],
                    feed_dict=feed_dict)
                writer.add_summary(summary, global_step=index)
                total_loss += loss_batch
                if (index + 1) % SKIP_STEP == 0:
                    print('Average loss at step {}: {:5.1f}'.format(index, total_loss / SKIP_STEP))
                    total_loss = 0.0
                    saver.save(sess, 'checkpoints/skip-gram', index)
        finally:
            # The writer buffers events; close it so the final summaries are
            # flushed to disk even if training stops early or raises.
            writer.close()

        ####################
        # Code to visualize the embeddings. Uncomment the lines below to enable it,
        # then run "tensorboard --logdir=processed" to see the embeddings.
        # final_embed_matrix = sess.run(model.embed_matrix)

        # # It has to be a variable; constants don't work here. You can't reuse model.embed_matrix.
        # embedding_var = tf.Variable(final_embed_matrix[:1000], name='embedding')
        # sess.run(embedding_var.initializer)

        # config = projector.ProjectorConfig()
        # summary_writer = tf.summary.FileWriter('processed')

        # # add the embedding to the projector config
        # embedding = config.embeddings.add()
        # embedding.tensor_name = embedding_var.name

        # # link this tensor to its metadata file, in this case the first 1000 words of the vocab
        # embedding.metadata_path = 'processed/vocab_1000.tsv'

        # # saves a configuration file that TensorBoard will read during startup.
        # projector.visualize_embeddings(summary_writer, config)
        # saver_embed = tf.train.Saver([embedding_var])
        # saver_embed.save(sess, 'processed/model3.ckpt', 1)
        # summary_writer.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号