tfrecord_read.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Youtube8mdataset_kagglechallenge 作者: jasonlee27 项目源码 文件源码
def prepare_reader(self,
                   filename_queue,
                   max_quantized_value=2,
                   min_quantized_value=-2):
        """Build the graph ops that read and parse one TFRecord entry.

        One serialized record is pulled from ``filename_queue`` with a
        ``tf.TFRecordReader`` and parsed via
        ``tf.parse_single_sequence_example``.  The context always carries a
        scalar string ``"video_id"`` and a variable-length int64 ``"labels"``.

        The two features named by ``self.feature_name[0]`` and
        ``self.feature_name[1]`` are handled according to
        ``self.sequence_data``:

        * truthy: each is parsed as a sequence of byte strings, decoded as
          uint8, cast to float32, reshaped to ``[-1, self.feature_size[i]]``
          and passed through ``Dequantize``.
        * falsy: each is parsed as a fixed-length float32 context feature of
          shape ``self.feature_size[i]`` and returned as parsed.

        Args:
            filename_queue: Queue handed to ``TFRecordReader.read``.
            max_quantized_value: Forwarded to ``Dequantize``; only used when
                ``self.sequence_data`` is truthy.
            min_quantized_value: Forwarded to ``Dequantize``; only used when
                ``self.sequence_data`` is truthy.

        Returns:
            A tuple ``(video_id, video_matrix, audio_matrix, labels,
            num_frames)``.  ``num_frames`` is the row count of the second
            feature's decoded matrix capped at ``self.max_frames`` in
            sequence mode, and the constant ``-1`` otherwise.
        """
        record_reader = tf.TFRecordReader()
        _, serialized = record_reader.read(filename_queue)

        # Parsing specs: the context always holds the id and the labels.
        context_spec = {
            "video_id": tf.FixedLenFeature([], tf.string),
            "labels": tf.VarLenFeature(tf.int64),
        }
        sequence_spec = None
        if self.sequence_data:
            sequence_spec = {
                self.feature_name[0]: tf.FixedLenSequenceFeature([], dtype=tf.string),
                self.feature_name[1]: tf.FixedLenSequenceFeature([], dtype=tf.string),
            }
        else:
            for idx in (0, 1):
                context_spec[self.feature_name[idx]] = tf.FixedLenFeature(
                    self.feature_size[idx], tf.float32)

        parsed_context, parsed_sequences = tf.parse_single_sequence_example(
            serialized,
            context_features=context_spec,
            sequence_features=sequence_spec)
        labels = tf.cast(parsed_context["labels"].values, tf.int64)

        def dequantized_frames(idx):
            # Returns (decoded float matrix, dequantized matrix) for feature idx.
            # Dequantize is defined outside this block; presumably it rescales
            # the uint8 values into [min, max] -- TODO confirm.
            raw_bytes = tf.decode_raw(parsed_sequences[self.feature_name[idx]], tf.uint8)
            frames = tf.reshape(tf.cast(raw_bytes, tf.float32),
                                [-1, self.feature_size[idx]])
            return frames, Dequantize(frames, max_quantized_value, min_quantized_value)

        if self.sequence_data:
            _, video_matrix = dequantized_frames(0)
            audio_frames, audio_matrix = dequantized_frames(1)
            # The frame count is taken from the second feature only.
            num_frames = tf.minimum(tf.shape(audio_frames)[0], self.max_frames)
        else:
            video_matrix = parsed_context[self.feature_name[0]]
            audio_matrix = parsed_context[self.feature_name[1]]
            num_frames = tf.constant(-1)

        # NOTE(review): the matrices are returned with their native frame
        # count; no padding or truncation to self.max_frames is applied here,
        # even though num_frames is capped at it.
        return parsed_context["video_id"], video_matrix, audio_matrix, labels, num_frames
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号