def make_data_provider(self, **kwargs):
    """Create a slim data provider that yields image/caption examples.

    Records are read from ``self.params["files"]`` with ``tf.TFRecordReader``
    and decoded by ``TFSEquenceExampleDecoder`` using two feature specs:

    * context: the encoded image bytes (under ``self.params["image_field"]``)
      and an ``"image/format"`` string that defaults to
      ``self.params["image_format"]`` when a record does not carry one;
    * sequence: the caption as an int64 id sequence
      (``self.params["caption_ids_field"]``) and a string token sequence
      (``self.params["caption_tokens_field"]``).

    Items exposed by the provider:

    * ``"image"``: the image decoded with 3 channels, its format read from the
      ``"image/format"`` context feature.
    * ``"target_ids"``: the caption id sequence tensor.
    * ``"target_tokens"``: the caption token sequence tensor.
    * ``"target_len"``: ``tf.size`` of the caption token sequence.

    Args:
      **kwargs: Extra keyword arguments forwarded unchanged to
        ``DatasetDataProvider``. They must not include ``shuffle`` or
        ``num_epochs``; those are already supplied from ``self.params``, and
        repeating them raises ``TypeError`` (duplicate keyword argument).

    Returns:
      A ``tf.contrib.slim.dataset_data_provider.DatasetDataProvider`` over the
      decoded dataset, with ``shuffle`` and ``num_epochs`` taken from
      ``self.params``.
    """
    params = self.params

    image_key = params["image_field"]
    # Context features: one encoded image per record plus its format string.
    # Records lacking "image/format" fall back to the configured format.
    context_features = {
        image_key: tf.FixedLenFeature([], dtype=tf.string),
        "image/format": tf.FixedLenFeature(
            [], dtype=tf.string, default_value=params["image_format"]),
    }

    ids_key = params["caption_ids_field"]
    tokens_key = params["caption_tokens_field"]
    # Sequence features: the caption as parallel id and token sequences.
    sequence_features = {
        ids_key: tf.FixedLenSequenceFeature([], dtype=tf.int64),
        tokens_key: tf.FixedLenSequenceFeature([], dtype=tf.string),
    }

    def _caption_length(tensors):
        # Reads the tokens key from self.params at call time rather than
        # capturing it when the handler is built.
        return tf.size(tensors[self.params["caption_tokens_field"]])

    # Insertion order is kept as image, target_ids, target_tokens, target_len.
    handlers = {}
    handlers["image"] = tfexample_decoder.Image(
        image_key=image_key, format_key="image/format", channels=3)
    handlers["target_ids"] = tfexample_decoder.Tensor(ids_key)
    handlers["target_tokens"] = tfexample_decoder.Tensor(tokens_key)
    handlers["target_len"] = tfexample_decoder.ItemHandlerCallback(
        keys=[tokens_key], func=_caption_length)

    decoder = TFSEquenceExampleDecoder(
        context_features, sequence_features, handlers)

    # num_samples is left unknown (None); no item descriptions are provided.
    dataset = tf.contrib.slim.dataset.Dataset(
        data_sources=self.params["files"],
        reader=tf.TFRecordReader,
        decoder=decoder,
        num_samples=None,
        items_to_descriptions={},
    )

    provider_cls = tf.contrib.slim.dataset_data_provider.DatasetDataProvider
    return provider_cls(
        dataset=dataset,
        shuffle=params["shuffle"],
        num_epochs=params["num_epochs"],
        **kwargs)
评论列表
文章目录