def _create_predictions(self, decoder_output, features, labels, losses=None):
  """Creates the dictionary of predictions that is returned by the model.

  Args:
    decoder_output: A decoder output namedtuple in time-major form
      ``[T, B, ...]``; its fields are flattened into the predictions.
    features: Dictionary of input features; flattened into the predictions
      under the ``features.`` prefix.
    labels: Optional dictionary of labels; flattened under the ``labels.``
      prefix when provided.
    losses: Optional time-major loss tensor; transposed to batch-major and
      stored under the ``losses`` key when provided.

  Returns:
    A dict mapping flattened keys to (batch-major) tensors. If the decoder
    emitted ``predicted_ids``, a ``predicted_tokens`` entry is added by
    looking the ids up in the target vocabulary table.
  """
  predictions = {}
  # Add features and, if available, labels to predictions
  predictions.update(_flatten_dict({"features": features}))
  if labels is not None:
    predictions.update(_flatten_dict({"labels": labels}))
  if losses is not None:
    predictions["losses"] = _transpose_batch_time(losses)
  # Decoders return output in time-major form [T, B, ...].
  # Here we transpose everything back to batch-major for the user.
  output_dict = collections.OrderedDict(
      zip(decoder_output._fields, decoder_output))
  decoder_output_flat = _flatten_dict(output_dict)
  decoder_output_flat = {
      k: _transpose_batch_time(v)
      for k, v in decoder_output_flat.items()
  }
  predictions.update(decoder_output_flat)
  # If we predict ids, also map them back into the vocab as raw tokens.
  # NOTE: membership is tested on the dict directly (no .keys() needed).
  if "predicted_ids" in predictions:
    vocab_tables = graph_utils.get_dict_from_collection("vocab_tables")
    target_id_to_vocab = vocab_tables["target_id_to_vocab"]
    # tf.cast replaces the deprecated tf.to_int64 alias; the lookup table
    # requires int64 keys.
    predicted_tokens = target_id_to_vocab.lookup(
        tf.cast(predictions["predicted_ids"], tf.int64))
    # Raw predicted tokens
    predictions["predicted_tokens"] = predicted_tokens
  return predictions
# NOTE(review): removed stray scraped-page artifacts ("评论列表" / "文章目录",
# i.e. "comment list" / "table of contents") — bare text here is a SyntaxError.