h5_to_tf.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:DMNN 作者: magnux 项目源码 文件源码
def export_to_tf(self, shard_size=4096):
    """Export the train and validation splits to sharded TFRecord files.

    For each split, one ``tf.train.Example`` is written per position in the
    split's key list, and a new shard file is started every ``shard_size``
    records. Files are written to ``self.data_path`` and named
    ``<data_set>_<train|val>_shard<N>.tf`` with ``N`` counting up from 0.
    Shards are opened lazily, so a split with no keys produces no files and
    no empty trailing shard is ever created.

    Each example carries the int64 features ``key_idx``, ``subject``,
    ``action`` and ``plen`` plus a flat float feature ``pose``.

    Args:
        shard_size: Maximum number of records per shard file. Defaults to
            4096, the value that used to be hard-coded. Must be >= 1.

    Raises:
        ValueError: If ``shard_size`` is less than 1.
    """
    if shard_size < 1:
        raise ValueError('shard_size must be a positive integer, got %r' % (shard_size,))

    def make_example(key_idx, subject, action, pose, plen):
        """Build a tf.train.Example from one record returned by read_h5_data."""
        ex = tf.train.Example()
        feature = ex.features.feature
        feature["key_idx"].int64_list.value.append(int(key_idx))
        feature["subject"].int64_list.value.append(int(subject))
        feature["action"].int64_list.value.append(int(action))
        feature["plen"].int64_list.value.append(int(plen))
        # pose.tolist() is walked exactly three levels deep, so pose must be a
        # rank-3 array-like. Only the flattened (row-major) values are stored,
        # not the shape.
        flat_pose = [
            value
            for sublist in pose.tolist()
            for subsublist in sublist
            for value in subsublist
        ]
        # Touch the "pose" map entry only when there is data, so an empty pose
        # yields an example without a "pose" feature (same as appending
        # value-by-value in a loop that never runs).
        if flat_pose:
            feature["pose"].float_list.value.extend(flat_pose)
        return ex

    def shard_path(splitname, shard):
        """Return the output file path for shard number ``shard`` of a split."""
        return os.path.join(
            self.data_path,
            self.data_set + '_' + splitname + '_shard' + str(shard) + '.tf',
        )

    def write_split(is_training, keys):
        """Write one split, rolling over to a new shard every shard_size records."""
        splitname = 'train' if is_training else 'val'
        print('Transforming "%s" split...' % splitname)
        writer = None
        shard = 0
        try:
            # NOTE(review): read_h5_data receives the position k
            # (0..len(keys)-1), not keys[k]; presumably it resolves the key
            # itself -- confirm against its definition.
            for k in trange(len(keys), dynamic_ncols=True):
                if writer is None:
                    # Open the next shard lazily, only when a record is
                    # about to be written.
                    # NOTE: tf.python_io.TFRecordWriter is the TF 1.x name
                    # (tf.io.TFRecordWriter in TF 2.x).
                    writer = tf.python_io.TFRecordWriter(shard_path(splitname, shard))
                key_idx, subject, action, pose, plen = self.read_h5_data(k, is_training)
                ex = make_example(key_idx, subject, action, pose, plen)
                writer.write(ex.SerializeToString())
                if (k + 1) % shard_size == 0:
                    writer.close()
                    writer = None
                    shard += 1
        finally:
            # Close the current shard even if reading or serializing fails,
            # so the file handle is not leaked.
            if writer is not None:
                writer.close()

    write_split(True, self.train_keys)
    write_split(False, self.val_keys)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号