import tensorflow as tf


def get_evaluation_input(inputs, params):
    # inputs[0] holds the source sentences; inputs[1:] hold the reference
    # translations, one list per reference.
    with tf.device("/cpu:0"):
        # Create datasets
        datasets = []

        for data in inputs:
            dataset = tf.data.Dataset.from_tensor_slices(data)
            # Split string on whitespace into a vector of tokens
            dataset = dataset.map(lambda x: tf.string_split([x]).values,
                                  num_parallel_calls=params.num_threads)
            # Append <eos>
            dataset = dataset.map(
                lambda x: tf.concat([x, [tf.constant(params.eos)]], axis=0),
                num_parallel_calls=params.num_threads
            )
            datasets.append(dataset)

        dataset = tf.data.Dataset.zip(tuple(datasets))

        # Convert tuple to dictionary
        dataset = dataset.map(
            lambda *x: {
                "source": x[0],
                "source_length": tf.shape(x[0])[0],
                "references": x[1:]
            },
            num_parallel_calls=params.num_threads
        )

        # Pad each feature to the longest entry in the batch
        dataset = dataset.padded_batch(
            params.eval_batch_size,
            {
                "source": [tf.Dimension(None)],
                "source_length": [],
                "references": (tf.Dimension(None),) * (len(inputs) - 1)
            },
            {
                "source": params.pad,
                "source_length": 0,
                "references": (params.pad,) * (len(inputs) - 1)
            }
        )

        iterator = dataset.make_one_shot_iterator()
        features = iterator.get_next()

        # Map tokens to ids; out-of-vocabulary tokens fall back to <unk>
        src_table = tf.contrib.lookup.index_table_from_tensor(
            tf.constant(params.vocabulary["source"]),
            default_value=params.mapping["source"][params.unk]
        )
        tgt_table = tf.contrib.lookup.index_table_from_tensor(
            tf.constant(params.vocabulary["target"]),
            default_value=params.mapping["target"][params.unk]
        )
        features["source"] = src_table.lookup(features["source"])
        features["references"] = tuple(
            tgt_table.lookup(item) for item in features["references"]
        )

        return features
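
As a quick sanity check, the function can be driven end to end with toy data. A minimal sketch follows, assuming TensorFlow 1.x; the Params class and the tiny vocabularies below are hypothetical placeholders for illustration only, since the real params object comes from the surrounding configuration and carries more fields.

# Hypothetical stand-in for the real configuration object.
class Params(object):
    eos = "<eos>"
    pad = "<pad>"
    unk = "<unk>"
    num_threads = 4
    eval_batch_size = 2
    vocabulary = {
        "source": ["<pad>", "<eos>", "<unk>", "a", "b"],
        "target": ["<pad>", "<eos>", "<unk>", "x", "y"]
    }
    mapping = {
        "source": {"<pad>": 0, "<eos>": 1, "<unk>": 2, "a": 3, "b": 4},
        "target": {"<pad>": 0, "<eos>": 1, "<unk>": 2, "x": 3, "y": 4}
    }

# inputs[0]: source sentences, inputs[1]: one set of reference translations
inputs = [["a b", "b"], ["x y", "y x"]]
features = get_evaluation_input(inputs, Params())

with tf.Session() as sess:
    # The lookup tables must be initialized before the first run
    sess.run(tf.tables_initializer())
    batch = sess.run(features)
    # batch["source"] is a padded int64 id matrix; batch["references"]
    # is a tuple with one padded id matrix per reference set.
    print(batch["source"], batch["source_length"])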