video_label_prediction.py source file

python
Project: video_labelling_using_youtube8m  Author: LittleWat
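import pprint

import numpy as np
import pandas as pd
import tensorflow as tf  # TensorFlow 1.x API (sessions, graph collections)

# Assumption: Dequantize is the helper of the same name from the YouTube-8M
# starter code (utils.py); it is not defined in this snippet, so adjust this
# import to match your project layout.
from utils import Dequantize


def print_predicted_label(feature, topn=10, latest_checkpoint='./yt8m_model/model.ckpt-2833',
                          id2label_csv='./label_names.csv'):
    # Build a mapping from class index to human-readable label name.
    id2label_ser = pd.read_csv(id2label_csv, index_col=0)
    id2label = id2label_ser.to_dict()['label_name']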

    meta_graph_location = latest_checkpoint + ".meta"

    sess = tf.InteractiveSession()

    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    saver.restore(sess, latest_checkpoint)

    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

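    # Workaround for the num_epochs issue: local variables whose names contain
    # "train_input" are assigned 1; all other local variables are initialized
    # normally.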
    def set_up_init_ops(variables):
        init_op_list = []
        for variable in list(variables):
            if "train_input" in variable.name:
                init_op_list.append(tf.assign(variable, 1))
                variables.remove(variable)
        init_op_list.append(tf.variables_initializer(variables))
        return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(
            tf.GraphKeys.LOCAL_VARIABLES)))

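    # Zero-pad the frame-level features to a fixed 300 frames x 1024 dims, then
    # add a leading batch axis. Inputs longer than 300 frames raise an error here.
    padded_feature = np.zeros([300, 1024])
    padded_feature[:feature.shape[0], :] = Dequantize(feature)
    video_batch_val = padded_feature[np.newaxis, :, :].astype(np.float32)
    # The true (unpadded) frame count is fed to the graph alongside the batch.
    num_frames_batch_val = np.array([feature.shape[0]], dtype=np.int32)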

    predictions_val, = sess.run([predictions_tensor], feed_dict={input_tensor: video_batch_val,
                                                                 num_frames_tensor: num_frames_batch_val})

    predictions_val = predictions_val.flatten()

    top_idxes = np.argsort(predictions_val)[::-1][:topn]

    pprint.pprint([(id2label[x], predictions_val[x]) for x in top_idxes])
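
The padding and top-n selection steps in the function above can be tried without TensorFlow or a checkpoint. The following is a minimal sketch under stated assumptions, not the project's own code: the helper names `pad_to_batch` and `top_n`, the toy scores, and the toy label map are all hypothetical, and unlike the original this version truncates inputs longer than 300 frames instead of raising.

```python
import numpy as np

MAX_FRAMES = 300    # padded length used in the function above
FEATURE_DIM = 1024  # feature width used in the function above


def pad_to_batch(feature, max_frames=MAX_FRAMES):
    """Zero-pad a (num_frames, dim) array to shape (1, max_frames, dim).

    Hypothetical helper: frames beyond max_frames are dropped, which the
    original function does not do.
    """
    num_frames = min(feature.shape[0], max_frames)
    padded = np.zeros((max_frames, feature.shape[1]), dtype=np.float32)
    padded[:num_frames] = feature[:num_frames]
    # Return the batch plus the true frame count, mirroring the two feeds.
    return padded[np.newaxis], np.array([num_frames], dtype=np.int32)


def top_n(scores, id2label, n=10):
    """Return the n highest-scoring (label, score) pairs, best first."""
    scores = np.asarray(scores).ravel()
    order = np.argsort(scores)[::-1][:n]  # sort ascending, reverse, keep n
    return [(id2label[int(i)], float(scores[i])) for i in order]


# Toy stand-ins for a real feature matrix and real model predictions.
feature = np.ones((5, FEATURE_DIM), dtype=np.float32)
batch, num_frames = pad_to_batch(feature)

scores = np.array([0.1, 0.9, 0.3, 0.7])
labels = {0: "a", 1: "b", 2: "c", 3: "d"}
top2 = top_n(scores, labels, n=2)
```

Keeping these two steps as separate functions makes them testable in isolation; only the middle step (restoring the checkpoint and running `predictions_tensor`) needs the trained model.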