album_genre_api.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:wiki-album-genre 作者: aliostad 项目源码 文件源码
def get_genre():
    """Predict a genre for each album name in the ``albums`` query parameter.

    ``albums`` is a comma-separated list of album names. The names are turned
    into id sequences with ``vocab_processor``, fed to the restored model's
    ``output/predictions`` tensor in batches of ``FLAGS.batch_size`` (one pass,
    no shuffling, so output order matches input order), and each predicted id
    is mapped to a genre through ``data_loader.genre_ids``.

    Returns:
        ``{"results": [genre, ...]}`` as JSON, one entry per album name;
        a JSON error with status 400 if ``albums`` is absent;
        a JSON error with status 500 if loading the model or inference fails
        (the traceback is logged).
    """
    albums = request.args.get('albums')
    if albums is None:
        # Without this guard albums.split() raised AttributeError and the
        # view ended up returning None.
        return jsonify({'error': "missing query parameter 'albums'"}), 400
    x_raw = albums.split(',')

    try:
        with graph.as_default():
            # The checkpoint is restored once per process and reused; see
            # _get_restored_session.
            sess = _get_restored_session()

            # Placeholders and output tensor, looked up by the names the
            # model was saved with.
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]

            x_test = np.array(list(vocab_processor.transform(x_raw)))

            # A single epoch, unshuffled, so predictions line up with x_raw.
            batches = data_loader.batch_iter(list(x_test), FLAGS.batch_size, 1, shuffle=False)
            # Keep probability 1.0: no dropout at inference time.
            batch_results = [
                sess.run(predictions, {input_x: batch, dropout_keep_prob: 1.0})
                for batch in batches
            ]
            # Concatenate once instead of re-copying the array per batch.
            all_predictions = np.concatenate(batch_results) if batch_results else []

            # A real list: on Python 3 map() returns an iterator that
            # jsonify cannot serialize.
            genres = [data_loader.genre_ids[int(p)] for p in all_predictions]
    except Exception:
        import logging
        logging.getLogger(__name__).exception("genre prediction failed")
        return jsonify({'error': 'genre prediction failed'}), 500

    return jsonify({'results': genres})


import threading as _threading

# Guards the one-time model load against concurrent first requests.
_model_lock = _threading.Lock()
# Holds the restored tf.Session under key 'session' once loaded.
_model_state = {}


def _get_restored_session():
    """Return the session holding the restored checkpoint, loading it on first use.

    The meta graph at ``checkpoint_file + ".meta"`` is imported into ``graph``
    and its variables restored from ``checkpoint_file`` exactly once per
    process. Later calls return the same session, so ``graph`` does not gain a
    new copy of the model on every request. If loading fails, the new session
    is closed and the exception propagates; the next call retries.
    """
    with _model_lock:
        sess = _model_state.get('session')
        if sess is None:
            session_conf = tf.ConfigProto(
                allow_soft_placement=FLAGS.allow_soft_placement,
                log_device_placement=FLAGS.log_device_placement)
            sess = tf.Session(graph=graph, config=session_conf)
            try:
                # import_meta_graph imports into the default graph.
                with graph.as_default():
                    saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
                    saver.restore(sess, checkpoint_file)
            except Exception:
                sess.close()
                raise
            _model_state['session'] = sess
        return sess
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号