generate_predictions.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:DL2W 作者: gauravmm 项目源码 文件源码
def _format_prediction_line(video_id, scores, num_k):
    """Format one row of the predictions CSV.

    Args:
        video_id: Decoded (str) identifier written in the ``VideoId`` column.
        scores: 1-D array of per-label scores for one example; a label is its
            index into ``scores``.
        num_k: Number of highest-scoring labels to emit.

    Returns:
        ``'<video_id>,<label> <score> <label> <score> ... '`` followed by a
        newline, labels ordered from highest to lowest score. Every pair,
        including the last, is followed by one space (same as before).
    """
    # argsort is ascending; reverse it so the highest scores come first.
    top_k = np.argsort(scores)[::-1][:num_k]
    # NOTE(review): '{:5f}' is "minimum width 5, default precision (6
    # decimals)", not "5 decimals" ('{:.5f}'). Kept unchanged so the output
    # format is stable; confirm which was intended.
    pairs = ''.join('{} {:5f} '.format(label, scores[label]) for label in top_k)
    return '{},{}\n'.format(video_id, pairs)


def _format_feature_line(video_id, feature_row):
    """Format one features row: the id, then each value as '{:6e}', comma-separated."""
    values = ','.join('{:6e}'.format(value) for value in feature_row)
    return '{},{}\n'.format(video_id, values)


def generate_predictions(tfrecord_file,
                         train_dir,
                         predictions_file,
                         features_file,
                         batch_size,
                         num_k,
                         write_features=False):
    """Run a trained model over a TFRecord file and write top-k predictions.

    Builds the input pipeline (a single pass, ``num_epochs=1``) and the
    inference graph, restores the latest checkpoint in ``train_dir``, then
    streams every example through the model.

    Args:
        tfrecord_file: Path of the TFRecord file to predict on.
        train_dir: Directory searched with ``tf.train.latest_checkpoint``.
        predictions_file: Output CSV. Gets a ``VideoId,LabelConfidencePairs``
            header and one row per example listing the ``num_k`` best labels
            and their scores, best first.
        features_file: Output path for per-example activations of the
            ``fc1/relu:0`` tensor. It is always opened for writing, so an
            existing file is truncated. Rows are written only when
            ``write_features`` is true.
        batch_size: Examples fetched per ``sess.run`` call. The input queue
            capacity is ``batch_size * 4``.
        num_k: Number of top labels written per example.
        write_features: If true, also fetch ``fc1/relu:0`` and write one
            comma-separated row per example to ``features_file``. Defaults to
            False, which matches the previous behaviour (empty features file).

    Raises:
        ValueError: If no checkpoint is found in ``train_dir``.
    """
    ids, vectors, _ = data_loader.inputs([tfrecord_file], batch_size=batch_size,
                                         num_threads=16, capacity=batch_size*4,
                                         num_epochs=1, is_training=False)

    predictions = model.inference(vectors)
    fetches = [ids, predictions]
    if write_features:
        # Layer looked up by name in the graph built by model.inference.
        # Presumably (batch, dim) -- TODO confirm.
        fetches.append(tf.get_default_graph().get_tensor_by_name('fc1/relu:0'))

    # Model weights come from the checkpoint. Local variables are initialized
    # separately (presumably the num_epochs=1 pipeline keeps its epoch
    # counter there -- confirm against data_loader).
    init_op = tf.local_variables_initializer()
    saver = tf.train.Saver()

    checkpoint = tf.train.latest_checkpoint(train_dir)
    if checkpoint is None:
        # Otherwise saver.restore would be called with None and fail obscurely.
        raise ValueError('No checkpoint found in {!r}'.format(train_dir))

    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, checkpoint)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        # Always stop and join the queue-runner threads, even if inference or
        # file I/O raises part-way through.
        try:
            with open(predictions_file, 'w') as f1, open(features_file, 'w') as f2:
                f1.write('VideoId,LabelConfidencePairs\n')

                while not coord.should_stop():
                    try:
                        batch = sess.run(fetches)
                    except tf.errors.OutOfRangeError:
                        # The single input epoch is exhausted.
                        break

                    ids_out, predictions_out = batch[0], batch[1]
                    video_ids = [raw_id.decode() for raw_id in ids_out]
                    for video_id, scores in zip(video_ids, predictions_out):
                        f1.write(_format_prediction_line(video_id, scores, num_k))

                    if write_features:
                        for video_id, feature_row in zip(video_ids, batch[2]):
                            f2.write(_format_feature_line(video_id, feature_row))
        finally:
            coord.request_stop()
            coord.join(threads)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号