inference_embedding.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:youtube-8m 作者: wangheda 项目源码 文件源码
def inference(video_id_batch, prediction_batch, label_batch, saver, out_file_location,
              num_classes=4716):
    """Restore a checkpoint and save one prediction row per class with np.savetxt.

    ``label_batch`` is fed with a ``num_classes`` x ``num_classes`` identity
    matrix, so row ``i`` of the saved array is the value of
    ``prediction_batch`` evaluated with a one-hot label for class ``i``.

    Args:
      video_id_batch: Not used by this function.
      prediction_batch: Tensor whose evaluated value is written to disk.
      label_batch: Tensor that is fed with the identity matrix.
      saver: Saver used to restore the checkpoint into the session.
      out_file_location: Path handed to ``np.savetxt``.
      num_classes: Size of the identity matrix fed to ``label_batch``.
        Defaults to 4716, the value previously hard-coded here
        (NOTE(review): presumably the YouTube-8M label vocabulary size).

    Returns:
      The global step parsed from the checkpoint file name, as a string
      (e.g. ``"0"`` for ``model.ckpt-0``), or ``-1`` when no checkpoint
      was found.
    """
    global_step_val = -1
    with tf.Session() as sess:
        # An explicitly configured checkpoint wins over the newest one in train_dir.
        if FLAGS.model_checkpoint_path:
            checkpoint = FLAGS.model_checkpoint_path
        else:
            checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
        if not checkpoint:
            logging.info("No checkpoint file found.")
            return global_step_val

        logging.info("Loading checkpoint for eval: " + checkpoint)
        # Restores from checkpoint
        saver.restore(sess, checkpoint)
        # Assuming model_checkpoint_path looks something like:
        # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
        global_step_val = checkpoint.split("/")[-1].split("-")[-1]

        sess.run([tf.local_variables_initializer()])

        # Workaround for num_epochs issue.
        def set_up_init_ops(variables):
            """Build ops that set 'train_input' locals to 1 and initialize the rest.

            ``variables`` is the graph's live LOCAL_VARIABLES collection (it comes
            from ``tf.get_collection_ref``), so removing an entry here also removes
            it from that collection. NOTE(review): the 'train_input' variables are
            presumably the input pipeline's epoch counters — confirm.
            """
            init_op_list = []
            # Iterate over a copy because matching entries are removed from `variables`.
            for variable in list(variables):
                if "train_input" in variable.name:
                    init_op_list.append(tf.assign(variable, 1))
                    variables.remove(variable)
            init_op_list.append(tf.variables_initializer(variables))
            return init_op_list

        sess.run(set_up_init_ops(tf.get_collection_ref(
            tf.GraphKeys.LOCAL_VARIABLES)))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # One one-hot label row per class.
        input_indices = np.eye(num_classes)
        try:
            logging.info("start saving parameters")
            predictions = sess.run(prediction_batch,
                                   feed_dict={label_batch: input_indices})
            np.savetxt(out_file_location, predictions)
        except tf.errors.OutOfRangeError:
            # Input queues ran dry before the predictions were produced, so
            # nothing was written; previously this path claimed success.
            logging.warning("Input exhausted before inference finished; nothing "
                            "was written to %s", out_file_location)
        else:
            logging.info('Done with inference. The output file was written to ' + out_file_location)
        finally:
            coord.request_stop()
        coord.join(threads)
    # The `with` block closes the session; no explicit sess.close() is needed.
    return global_step_val
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号