# Imports assumed by this snippet (TensorFlow 1.x APIs); BeamSearchOptimizationDecoder
# and ParentFeedingCellWrapper are project-local classes defined elsewhere in this repo.
import tensorflow as tf
from tensorflow.python.layers import core as tf_core_layers
from tensorflow.python.util import nest
from tensorflow.contrib.seq2seq import AttentionWrapper, LuongAttention


def add_decoder_op(self, enc_final_state, enc_hidden_states, output_embed_matrix, training):
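    """Build the decoder: a multi-layer RNN cell, optional Luong attention
    over the encoder hidden states, and a beam search decoder initialized
    from the encoder's final state. Returns the time-major decoder outputs.
    """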
    cell_dec = tf.contrib.rnn.MultiRNNCell([self.make_rnn_cell(i, for_decoder=True)
                                            for i in range(self.config.rnn_layers)])
    encoder_hidden_size = int(enc_hidden_states.get_shape()[-1])
    decoder_hidden_size = int(cell_dec.output_size)

    # if encoder and decoder have different sizes, add a projection layer
    if encoder_hidden_size != decoder_hidden_size:
        with tf.variable_scope('hidden_projection'):
            kernel = tf.get_variable('kernel', (encoder_hidden_size, decoder_hidden_size),
                                     dtype=tf.float32)
            # apply a relu to the projection for good measure
            enc_final_state = nest.map_structure(lambda x: tf.nn.relu(tf.matmul(x, kernel)),
                                                 enc_final_state)
            enc_hidden_states = tf.nn.relu(tf.tensordot(enc_hidden_states, kernel, [[2], [1]]))
    else:
        # flatten and repack the state
        enc_final_state = nest.pack_sequence_as(cell_dec.state_size, nest.flatten(enc_final_state))
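
    # The beam width may differ between training and inference; every tensor
    # the decoder reads must be tiled beam_width times along the batch axis
    # so it lines up with the beams.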
    beam_width = self.config.training_beam_size if training else self.config.beam_size
    # cell_dec = ParentFeedingCellWrapper(cell_dec, tf.contrib.seq2seq.tile_batch(enc_final_state, beam_width))
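
    # Wrap the decoder cell with Luong attention over the tiled encoder hidden
    # states; the tiled input lengths mask out padding. Because the wrapper is
    # given an initial_cell_state, zero_state() below yields an attention state
    # whose inner cell state is the tiled encoder final state.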
    if self.config.apply_attention:
        attention = LuongAttention(decoder_hidden_size,
                                   tf.contrib.seq2seq.tile_batch(enc_hidden_states, beam_width),
                                   tf.contrib.seq2seq.tile_batch(self.input_length_placeholder, beam_width),
                                   probability_fn=tf.nn.softmax)
        cell_dec = AttentionWrapper(cell_dec, attention,
                                    cell_input_fn=lambda inputs, _: inputs,
                                    attention_layer_size=decoder_hidden_size,
                                    initial_cell_state=tf.contrib.seq2seq.tile_batch(enc_final_state, beam_width))
        enc_final_state = cell_dec.zero_state(self.batch_size * beam_width, dtype=tf.float32)
    else:
        enc_final_state = tf.contrib.seq2seq.tile_batch(enc_final_state, beam_width)
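
    # Project decoder outputs onto the output vocabulary; every beam starts
    # from the grammar's start token and terminates on the end token.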
    linear_layer = tf_core_layers.Dense(self.config.output_size)
    go_vector = tf.ones((self.batch_size,), dtype=tf.int32) * self.config.grammar.start
    decoder = BeamSearchOptimizationDecoder(training, cell_dec, output_embed_matrix, go_vector,
                                            self.config.grammar.end, enc_final_state,
                                            beam_width=beam_width,
                                            output_layer=linear_layer,
                                            gold_sequence=self.output_placeholder if training else None,
                                            gold_sequence_length=(self.output_length_placeholder + 1) if training else None)

    if self.config.use_grammar_constraints:
        raise NotImplementedError("Grammar constraints are not implemented for the beam search yet")
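
    # Run the decode loop; outputs are time-major and capped at max_length steps.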
    final_outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, output_time_major=True,
                                                            maximum_iterations=self.config.max_length)
    return final_outputs
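
The repeated tf.contrib.seq2seq.tile_batch calls above are what keep the encoder tensors aligned with the beam: each batch entry is replicated beam_width times along the batch axis. A minimal standalone sketch of that behavior (TensorFlow 1.x; the toy shapes are chosen purely for illustration):

import tensorflow as tf

enc = tf.constant([[1., 1., 1.],
                   [2., 2., 2.]])                       # batch of 2, hidden size 3
tiled = tf.contrib.seq2seq.tile_batch(enc, multiplier=4)

with tf.Session() as sess:
    print(sess.run(tf.shape(tiled)))                    # [8 3]: each row repeated 4 times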

Source: beam_aligner.py (Python)