def make_data_provider(self, **kwargs):
    """Build a ``DatasetDataProvider`` over this pipeline's caption records.

    Records listed in ``self.files`` are read with ``tf.TFRecordReader`` and
    decoded by a ``TFSequenceExampleDecoder``. The image bytes and their
    format come from the context features. The caption ids and caption tokens
    come from the sequence features. This presumably means the records are
    ``tf.SequenceExample`` protos; confirm against the data writer.

    Items exposed by the returned provider:
      * ``"image"``: the image decoded with 3 channels. Its format is read
        from ``"image/format"``, defaulting to ``self.image_format`` when the
        record has none.
      * ``"target_ids"``: the caption id sequence (int64).
      * ``"target_token"``: the caption token sequence (string).
      * ``"target_len"``: ``tf.size`` of the caption token sequence.

    Args:
      **kwargs: Extra keyword arguments forwarded unchanged to
        ``DatasetDataProvider``. This method always supplies ``dataset``,
        ``shuffle`` and ``num_epochs`` itself. Passing any of them here raises
        ``TypeError`` for a duplicate keyword argument.

    Returns:
      A ``DatasetDataProvider`` wrapping a ``Dataset`` with unknown
      ``num_samples`` and empty item descriptions.
    """
    decoders = tfslim.tfexample_decoder
    image_key = self.image_field
    ids_key = self.caption_ids_field
    tokens_key = self.caption_tokens_field
    format_key = "image/format"

    # Context (per-record) features: the encoded image bytes and its format.
    # Entries are assigned in the same order as the original literal, so a
    # key collision resolves identically (the later entry wins).
    context_features = dict()
    context_features[image_key] = tf.FixedLenFeature([], dtype=tf.string)
    context_features[format_key] = tf.FixedLenFeature(
        [], dtype=tf.string, default_value=self.image_format)

    # Sequence features: the caption as ids and as token strings.
    sequence_features = {
        ids_key: tf.FixedLenSequenceFeature([], dtype=tf.int64),
        tokens_key: tf.FixedLenSequenceFeature([], dtype=tf.string),
    }

    def _token_count(tensors):
        # The attribute is looked up at call time, matching the original
        # inline lambda's late binding.
        return tf.size(tensors[self.caption_tokens_field])

    handlers = {
        "image": decoders.Image(
            image_key=image_key, format_key=format_key, channels=3),
        "target_ids": decoders.Tensor(ids_key),
        # NOTE(review): this item name is singular while "target_ids" is
        # plural. Confirm that downstream consumers expect "target_token".
        "target_token": decoders.Tensor(tokens_key),
        "target_len": decoders.ItemHandlerCallback(
            keys=[tokens_key], func=_token_count),
    }

    record_decoder = TFSequenceExampleDecoder(
        context_features, sequence_features, handlers)
    caption_dataset = Dataset(
        data_sources=self.files,
        reader=tf.TFRecordReader,
        decoder=record_decoder,
        num_samples=None,
        items_to_descriptions={})
    return DatasetDataProvider(
        dataset=caption_dataset,
        shuffle=self.shuffle,
        num_epochs=self.num_epochs,
        **kwargs)
评论列表
文章目录