def after_run(self, _run_context, run_values):
    """Turn every example of the fetched batch into a sentence and print it.

    ``run_values.results`` is split into per-example dicts with
    ``unbatch_dict``. For each example the predicted tokens are decoded to
    unicode, optionally passed through ``self._unk_replace_fn``, joined with
    ``self.params["delimiter"]``, cut at the first ``"SEQUENCE_END"``,
    optionally post-processed by ``self._postproc_fn``, stripped and
    printed to stdout.

    Args:
        _run_context: Unused.
        run_values: Object whose ``results`` attribute holds the batched
            fetches.
    """
    for example in unbatch_dict(run_values.results):
        # Byte strings -> unicode. The decoded array is also written back
        # into the example dict under the same key.
        raw_prediction = example["predicted_tokens"].astype("S")
        tokens = example["predicted_tokens"] = np.char.decode(
            raw_prediction, "utf-8")

        # With beam search the array has an extra axis; keep the first beam.
        if np.ndim(tokens) > 1:
            tokens = tokens[:, 0]

        raw_source = example["features.source_tokens"].astype("S")
        source = example["features.source_tokens"] = np.char.decode(
            raw_source, "utf-8")
        length = example["features.source_len"]

        replace_unk = self._unk_replace_fn
        if replace_unk is not None:
            # Keep only the first ``source_len - 1`` attention columns so
            # that an UNK is never replaced with a SEQUENCE_END token.
            scores = example["attention_scores"][:, :length - 1]
            tokens = replace_unk(
                source_tokens=source,
                predicted_tokens=tokens,
                attention_scores=scores)

        # Everything from the first end-of-sequence marker onward is dropped.
        sentence = self.params["delimiter"].join(tokens)
        sentence = sentence.split("SEQUENCE_END")[0]

        # Optional user-supplied post-processing of the joined string.
        postprocess = self._postproc_fn
        if postprocess:
            sentence = postprocess(sentence)

        print(sentence.strip())
评论列表
文章目录