def decode(self, cell_dec, enc_final_state, output_size, output_embed_matrix, training, grammar_helper=None):
    """Run the seq2seq decoder and return its final outputs.

    Builds the output projection, the decoding helper (teacher-forced at
    training time, greedy at inference time), the decoder itself (optionally
    grammar-constrained), and finally unrolls it with
    ``tf.contrib.seq2seq.dynamic_decode``.

    Args:
        cell_dec: RNN cell used by the decoder.
        enc_final_state: encoder final state used to initialize decoding.
        output_size: vocabulary size for the dense projection (ignored when
            dot-product output is enabled).
        output_embed_matrix: embedding matrix for output tokens.
        training: when True, teacher-forces on ``self.output_placeholder``.
        grammar_helper: optional helper forwarded to the grammar decoder.

    Returns:
        The ``final_outputs`` structure produced by ``dynamic_decode``.
    """
    # Projection from decoder hidden state to the output vocabulary.
    if self.config.use_dot_product_output:
        projection = DotProductLayer(output_embed_matrix)
    else:
        projection = tf.layers.Dense(output_size, use_bias=False)

    # One GO token per batch element to start decoding.
    start_tokens = tf.ones((self.batch_size,), dtype=tf.int32) * self.config.grammar.start

    if training:
        # Teacher forcing: feed GO + gold output ids, embedded.
        shifted_ids = tf.concat([tf.expand_dims(start_tokens, axis=1), self.output_placeholder], axis=1)
        decoder_inputs = tf.nn.embedding_lookup([output_embed_matrix], shifted_ids)
        # +1 on the lengths accounts for the prepended GO token.
        seq_helper = TrainingHelper(decoder_inputs, self.output_length_placeholder + 1)
    else:
        # Greedy decoding from GO until the grammar's end token.
        seq_helper = GreedyEmbeddingHelper(output_embed_matrix, start_tokens, self.config.grammar.end)

    if self.config.use_grammar_constraints:
        dec = GrammarBasicDecoder(
            self.config.grammar, cell_dec, seq_helper, enc_final_state,
            output_layer=projection,
            training_output=self.output_placeholder if training else None,
            grammar_helper=grammar_helper)
    else:
        dec = BasicDecoder(cell_dec, seq_helper, enc_final_state, output_layer=projection)

    decoded, _, _ = tf.contrib.seq2seq.dynamic_decode(
        dec, impute_finished=True, maximum_iterations=self.max_length)
    return decoded
# Provenance: excerpt from seq2seq_helpers.py (page-scrape metadata — view/favorite/like/comment counters and navigation labels — removed).