def decoder(self, decoder_inputs, enc_output, enc_states, target_sequence_length,
            source_sequence_length=None):
    """Run a bidirectional RNN decoder, optionally with Bahdanau attention.

    A stack of recurrent cells (one per layer, typed by
    ``self.hidden_layer_type``) is built separately for the forward and the
    backward direction. It is then run with ``tf.nn.bidirectional_dynamic_rnn``
    over ``decoder_inputs`` inside the ``"decoder"`` variable scope.

    Args:
        decoder_inputs: Decoder input tensor. When ``self.attention`` is set,
            its static last dimension is used as the attention size.
        enc_output: Encoder outputs; used as the attention memory.
        enc_states: Tuple ``(output_states_fw, output_states_bw)`` with the
            encoder's final forward/backward states. Used as the initial
            decoder states when attention is disabled.
        target_sequence_length: Per-example lengths of ``decoder_inputs``.
        source_sequence_length: Optional per-example lengths of ``enc_output``,
            used to mask the attention memory. Defaults to
            ``target_sequence_length`` to keep the previous behavior.

    Returns:
        The outputs tuple ``(output_fw, output_bw)`` returned by
        ``tf.nn.bidirectional_dynamic_rnn``.

    Raises:
        ValueError: If a layer type in ``self.hidden_layer_type`` is not one of
            ``"tanh"``, ``"lstm"`` or ``"gru"``.
    """
    cell_builders = {
        "tanh": tf.contrib.rnn.BasicRNNCell,
        "lstm": tf.contrib.rnn.BasicLSTMCell,
        "gru": GRUCell,
    }

    def _make_multicell():
        """Build a fresh multi-layer cell; every call yields independent weights."""
        layers = []
        for i in range(len(self.hidden_layer_size)):
            layer_type = self.hidden_layer_type[i]
            if layer_type not in cell_builders:
                # Previously an unknown type was silently skipped, producing a
                # shorter (or empty) stack; fail loudly instead.
                raise ValueError(
                    "Unsupported hidden_layer_type %r at layer %d" % (layer_type, i))
            # Unit counts come from encoder_layer_size: the non-attention path
            # feeds the encoder's final states in as initial states, which
            # requires the layer sizes to match.
            layers.append(cell_builders[layer_type](num_units=self.encoder_layer_size[i]))
        return MultiRNNCell(layers)

    with tf.variable_scope("decoder"):
        if not self.attention:
            # Separate cell instances per direction: reusing one instance for
            # both fw and bw either errors or ties their weights (TF-version
            # dependent).
            dec_output, _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=_make_multicell(),
                cell_bw=_make_multicell(),
                inputs=decoder_inputs,
                sequence_length=target_sequence_length,
                initial_state_fw=enc_states[0],
                initial_state_bw=enc_states[1])
        else:
            attention_size = decoder_inputs.get_shape().as_list()[-1]
            if source_sequence_length is None:
                # Backward-compatible default. NOTE(review): the memory is
                # enc_output, so its mask normally should use encoder lengths.
                source_sequence_length = target_sequence_length

            def _make_attention_cell(scope_name):
                """Wrap a fresh cell stack with its own Bahdanau attention."""
                # BahdanauAttention creates its memory layer at construction
                # time, so give each direction its own scope and mechanism.
                with tf.variable_scope(scope_name):
                    mechanism = tf.contrib.seq2seq.BahdanauAttention(
                        attention_size, enc_output,
                        memory_sequence_length=source_sequence_length,
                        normalize=True, probability_fn=tf.nn.softmax)
                    return tf.contrib.seq2seq.AttentionWrapper(
                        _make_multicell(), mechanism,
                        attention_layer_size=attention_size)

            # NOTE(review): unlike the non-attention path, enc_states are not
            # used here; both directions start from the wrapper's zero state.
            dec_output, _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=_make_attention_cell("attention_fw"),
                cell_bw=_make_attention_cell("attention_bw"),
                inputs=decoder_inputs,
                sequence_length=target_sequence_length,
                dtype=tf.float32)
    return dec_output
评论列表
文章目录