def _create_decoder(self, encoder_output, features, _labels):
  """Creates the attention decoder for this model."""
  # Resolve the attention class: try a fully qualified import path first,
  # then fall back to a class name defined in `decoders.attention`.
  attention_class = locate(self.params["attention.class"]) or \
      getattr(decoders.attention, self.params["attention.class"])
  attention_layer = attention_class(
      params=self.params["attention.params"], mode=self.mode)

  # If the input sequence is reversed we also need to reverse
  # the attention scores.
  reverse_scores_lengths = None
  if self.params["source.reverse"]:
    reverse_scores_lengths = features["source_len"]
    if self.use_beam_search:
      # Beam search replicates each batch element `beam_width` times,
      # so the per-example lengths must be tiled to stay aligned.
      reverse_scores_lengths = tf.tile(
          input=reverse_scores_lengths,
          multiples=[self.params["inference.beam_search.beam_width"]])

  return self.decoder_class(
      params=self.params["decoder.params"],
      mode=self.mode,
      vocab_size=self.target_vocab_info.total_size,
      attention_values=encoder_output.attention_values,
      attention_values_length=encoder_output.attention_values_length,
      attention_keys=encoder_output.outputs,
      attention_fn=attention_layer,
      reverse_scores_lengths=reverse_scores_lengths)
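The only non-obvious step above is the `tf.tile` call: during beam search each batch element is replicated `beam_width` times, so the per-example length tensor has to grow to match the beam-expanded attention scores. Below is a minimal, self-contained sketch of that behavior, assuming the TensorFlow 1.x API the surrounding code uses; the `source_len` values and beam width are hypothetical.

import tensorflow as tf

# Hypothetical inputs: a batch of two source sequences with lengths 5 and 7,
# decoded with a beam width of 3.
source_len = tf.constant([5, 7])
beam_width = 3

# tf.tile repeats the whole tensor beam_width times, mirroring the call in
# _create_decoder above.
reverse_scores_lengths = tf.tile(input=source_len, multiples=[beam_width])

with tf.Session() as sess:
    print(sess.run(reverse_scores_lengths))  # => [5 7 5 7 5 7]

Without this tiling, the lengths tensor would have shape `[batch_size]` while the beam-expanded scores have a leading dimension of `batch_size * beam_width`, and the score reversal would fail on the shape mismatch.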