def make_data_provider(self, **kwargs):
  """Creates a `ParallelDataProvider` reading the configured text files.

  Builds a slim `Dataset` over `self.params["source_files"]` (one example per
  text line, tokens split on `source_delimiter`, with SEQUENCE_END appended)
  and, if any `target_files` are configured, a parallel target `Dataset`
  (tokens wrapped in SEQUENCE_START / SEQUENCE_END).

  Args:
    **kwargs: Forwarded to the `ParallelDataProvider` constructor.

  Returns:
    A `parallel_data_provider.ParallelDataProvider` over the source dataset
    and the (possibly `None`) target dataset, shuffled / repeated according
    to `self.params["shuffle"]` and `self.params["num_epochs"]`.
  """
  def _make_dataset(data_sources, decoder):
    # Shared slim Dataset construction for source and target files.
    # num_samples is unknown ahead of time, hence None.
    return tf.contrib.slim.dataset.Dataset(
        data_sources=data_sources,
        reader=tf.TextLineReader,
        decoder=decoder,
        num_samples=None,
        items_to_descriptions={})

  decoder_source = split_tokens_decoder.SplitTokensDecoder(
      tokens_feature_name="source_tokens",
      length_feature_name="source_len",
      append_token="SEQUENCE_END",
      delimiter=self.params["source_delimiter"])
  dataset_source = _make_dataset(
      self.params["source_files"], decoder_source)

  # Target data is optional (e.g. absent at inference time).
  dataset_target = None
  if self.params["target_files"]:
    decoder_target = split_tokens_decoder.SplitTokensDecoder(
        tokens_feature_name="target_tokens",
        length_feature_name="target_len",
        prepend_token="SEQUENCE_START",
        append_token="SEQUENCE_END",
        delimiter=self.params["target_delimiter"])
    dataset_target = _make_dataset(
        self.params["target_files"], decoder_target)

  return parallel_data_provider.ParallelDataProvider(
      dataset1=dataset_source,
      dataset2=dataset_target,
      shuffle=self.params["shuffle"],
      num_epochs=self.params["num_epochs"],
      **kwargs)
# NOTE(review): removed stray non-code text that had been pasted in here
# ("comment list" / "article table of contents" in Chinese) — a web-scrape
# artifact that was not valid Python and would have broken the module.