n02_convert.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:kaggle_yt8m 作者: N01Z3 项目源码 文件源码
def tf2npz(tf_path, export_folder=FAST):
    vid_ids = []
    labels = []
    mean_rgb = []
    mean_audio = []
    tf_basename = os.path.basename(tf_path)
    npz_basename = tf_basename[:-len('.tfrecord')] + '.npz'
    isTrain = '/test' not in tf_path

    for example in tf.python_io.tf_record_iterator(tf_path):
        tf_example = tf.train.Example.FromString(example).features
        vid_ids.append(tf_example.feature['video_id'].bytes_list.value[0].decode(encoding='UTF-8'))
        if isTrain:
            labels.append(np.array(tf_example.feature['labels'].int64_list.value))
        mean_rgb.append(np.array(tf_example.feature['mean_rgb'].float_list.value).astype(np.float32))
        mean_audio.append(np.array(tf_example.feature['mean_audio'].float_list.value).astype(np.float32))

    save_path = export_folder + '/' + npz_basename
    np.savez(save_path,
             rgb=StandardScaler().fit_transform(np.array(mean_rgb)),
             audio=StandardScaler().fit_transform(np.array(mean_audio)),
             ids=np.array(vid_ids),
             labels=labels
             )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号