def prepare_reader(self, filename_queue, batch_size=1024):
    """Create a single reader for pre-aggregated YouTube-8M examples with predictions.

    Reads up to ``batch_size`` serialized ``tf.Example`` records from
    ``filename_queue``. It parses the ``video_id``, ``predictions`` and
    ``labels`` fields plus every feature listed in ``self.feature_names``.

    Args:
      filename_queue: A TensorFlow queue of filename locations.
      batch_size: Maximum number of records requested from ``read_up_to``
        per call (default 1024). Fewer records may be returned.

    Returns:
      A 5-tuple ``(video_ids, features, labels, ones, predictions)``:
        video_ids: string tensor with one entry per record read.
        features: float32 tensor made by concatenating, along axis 1, the
          parsed tensors of every name in ``self.feature_names``. Its shape is
          ``[num_records, sum(self.feature_sizes)]``.
        labels: dense indicator tensor built from the sparse ``labels`` field
          with ``tf.sparse_to_indicator``, shape ``[None, self.num_classes]``.
        ones: float32 tensor of ones with one entry per record read. The
          previous docstring called this "padding data". NOTE(review):
          presumably a per-example frame count of 1 for aggregated features;
          confirm against callers.
        predictions: float32 tensor of shape
          ``[num_records, self.num_classes]``.

    Raises:
      AssertionError: if ``self.feature_names`` is empty, or if its length
        differs from ``len(self.feature_sizes)``.
    """
    num_features = len(self.feature_names)
    # Validate the configuration before creating any graph ops (fail fast).
    assert num_features > 0, "self.feature_names is empty!"
    assert num_features == len(self.feature_sizes), (
        "length of feature_names (={}) != length of feature_sizes (={})".format(
            num_features, len(self.feature_sizes)))

    reader = tf.TFRecordReader()
    _, serialized_examples = reader.read_up_to(filename_queue, batch_size)

    # Map each field of the proto to its expected type and fixed shape.
    feature_map = {
        "video_id": tf.FixedLenFeature([], tf.string),
        "predictions": tf.FixedLenFeature([self.num_classes], tf.float32),
        "labels": tf.VarLenFeature(tf.int64),
    }
    for feature_name, feature_size in zip(self.feature_names,
                                          self.feature_sizes):
        feature_map[feature_name] = tf.FixedLenFeature([feature_size],
                                                       tf.float32)

    features = tf.parse_example(serialized_examples, features=feature_map)

    # Turn the sparse label ids into a dense per-class indicator.
    labels = tf.sparse_to_indicator(features["labels"], self.num_classes)
    labels.set_shape([None, self.num_classes])

    concatenated_features = tf.concat(
        [features[feature_name] for feature_name in self.feature_names], 1)

    # Size by the records actually read; read_up_to may return fewer than
    # batch_size.
    num_records = tf.shape(serialized_examples)[0]
    return (features["video_id"], concatenated_features, labels,
            tf.ones([num_records]), features["predictions"])
# (Scraped web-page residue, not part of the code: "Comment list" / "Table of contents".)