def decoder(self, decoder_inputs, encoder_state, name, lengths=None, train=True):
    """Run a GRU decoder in either teacher-forcing (train) or greedy inference mode.

    Args:
        decoder_inputs: Embedded ground-truth inputs fed at each step when
            ``train`` is True (ignored at inference). Presumably shaped
            (batch, time, embedding_size) -- TODO confirm against caller.
        encoder_state: Final encoder state used to initialize the decoder.
        name: Variable-scope name; also used to look up the pre-created
            output projection tensors ``name/weight:0`` and ``name/bias:0``.
        lengths: 1-D int tensor of per-example decoder lengths (train mode).
        train: Selects teacher-forcing training vs. greedy-embedding inference.

    Returns:
        train=True:  (logits over vocabulary reshaped to
                      (batch_size, max(lengths), vocabulary_size),
                      raw RNN outputs)
        train=False: int tensor of greedy token ids (argmax over vocab axis).
    """
    dec_cell = tf.contrib.rnn.GRUCell(self.para.embedding_size)
    # Output projection was created elsewhere in this graph; fetch by name so
    # both the train and inference branches share the same weights.
    W = self.graph.get_tensor_by_name(name + '/weight:0')
    b = self.graph.get_tensor_by_name(name + '/bias:0')
    if train:
        with tf.variable_scope(name) as varscope:
            # Teacher forcing: feed the ground-truth decoder_inputs each step.
            dynamic_fn_train = tf.contrib.seq2seq.simple_decoder_fn_train(encoder_state)
            outputs_train, state_train, _ = tf.contrib.seq2seq.dynamic_rnn_decoder(
                dec_cell, decoder_fn=dynamic_fn_train,
                inputs=decoder_inputs, sequence_length=lengths, scope=varscope)
            # Project RNN outputs to vocabulary logits: flatten time/batch,
            # apply the shared linear layer, then restore the 3-D shape.
            logits = tf.reshape(outputs_train, [-1, self.para.embedding_size])
            logits_train = tf.matmul(logits, W) + b
            logits_projected = tf.reshape(
                logits_train,
                [self.para.batch_size, tf.reduce_max(lengths), self.vocabulary_size])
            return logits_projected, outputs_train
    else:
        # reuse=True: inference shares the variables created in the train branch.
        with tf.variable_scope(name, reuse=True) as varscope:
            output_fn = lambda x: tf.nn.softmax(tf.matmul(x, W) + b)
            # Greedy decoding: feed back the embedding of the previous argmax
            # token; ids 2/3 are assumed to be <GO>/<EOS> -- TODO confirm.
            dynamic_fn_inference = tf.contrib.seq2seq.simple_decoder_fn_inference(
                output_fn=output_fn, encoder_state=encoder_state,
                embeddings=self.word_embeddings,
                start_of_sequence_id=2, end_of_sequence_id=3,
                maximum_length=self.max_sent_len,
                num_decoder_symbols=self.vocabulary_size)
            logits_inference, state_inference, _ = tf.contrib.seq2seq.dynamic_rnn_decoder(
                dec_cell, decoder_fn=dynamic_fn_inference, scope=varscope)
            # tf.argmax replaces the deprecated tf.arg_max alias; axis 2 is
            # the vocabulary dimension.
            return tf.argmax(logits_inference, 2)