def get_inference_input(inputs, params):
    """Build the feature dict consumed by the decoder at inference time.

    Converts raw source sentences into a padded, batched tf.data pipeline
    of token-id tensors and returns the one-shot iterator's next element.

    Args:
        inputs: sequence of source sentences (whitespace-separated tokens).
        params: hyper-parameter object providing ``num_threads``, ``eos``,
            ``pad``, ``unk``, ``decode_batch_size``, ``vocabulary["source"]``
            and ``mapping["source"]``.

    Returns:
        dict with "source" (padded int token-id matrix) and
        "source_length" (per-sentence token count including <eos>).
    """
    ds = tf.data.Dataset.from_tensor_slices(tf.constant(inputs))

    # Tokenize each sentence on whitespace.
    ds = ds.map(
        lambda s: tf.string_split([s]).values,
        num_parallel_calls=params.num_threads
    )
    # Terminate every token sequence with the <eos> symbol.
    ds = ds.map(
        lambda toks: tf.concat([toks, [tf.constant(params.eos)]], axis=0),
        num_parallel_calls=params.num_threads
    )
    # Package tokens and their count into a feature dictionary.
    ds = ds.map(
        lambda toks: {"source": toks, "source_length": tf.shape(toks)[0]},
        num_parallel_calls=params.num_threads
    )
    # Batch, padding variable-length token sequences with the <pad> symbol.
    ds = ds.padded_batch(
        params.decode_batch_size,
        {"source": [tf.Dimension(None)], "source_length": []},
        {"source": params.pad, "source_length": 0}
    )

    features = ds.make_one_shot_iterator().get_next()

    # Map token strings to vocabulary ids; out-of-vocabulary tokens fall
    # back to the id of <unk>.
    src_table = tf.contrib.lookup.index_table_from_tensor(
        tf.constant(params.vocabulary["source"]),
        default_value=params.mapping["source"][params.unk]
    )
    features["source"] = src_table.lookup(features["source"])
    return features
# (removed scraped webpage navigation residue: "comment list" / "table of contents")