def encoder(self, inputs, inputs_sequence_length):
    """Encode a batch of sequences with a stacked bidirectional RNN.

    One recurrent layer is stacked per entry of ``self.hidden_layer_size``.
    Layer ``i`` uses the cell type named by ``self.hidden_layer_type[i]``
    ("tanh" -> BasicRNNCell, "lstm" -> BasicLSTMCell, "gru" -> GRUCell)
    with ``self.encoder_layer_size[i]`` units. The forward and backward
    directions each get their own freshly built cell stack, so they do not
    share cell objects.

    Args:
        inputs: Tensor passed as ``inputs`` to
            ``tf.nn.bidirectional_dynamic_rnn``. ``time_major`` is left at
            its default, so this is presumably (batch, time, features).
            TODO confirm against callers.
        inputs_sequence_length: Per-example lengths, forwarded as
            ``sequence_length``.

    Returns:
        ``(enc_output, enc_state)``.
        ``enc_output`` is the forward and backward outputs concatenated
        along axis 2.
        ``enc_state`` is the ``(fw_state, bw_state)`` tuple returned by
        ``bidirectional_dynamic_rnn``.

    Raises:
        ValueError: if any entry of ``self.hidden_layer_type`` (within the
            first ``len(self.hidden_layer_size)`` layers) is not one of
            "tanh", "lstm" or "gru".
    """
    # NOTE(review): the layer count is taken from hidden_layer_size while the
    # widths come from encoder_layer_size. Confirm these are meant to have
    # the same length.
    n_layers = len(self.hidden_layer_size)
    layer_specs = [(self.hidden_layer_type[i], self.encoder_layer_size[i])
                   for i in range(n_layers)]

    # Validate before building any graph ops. Previously an unrecognized
    # type matched none of the branches and that layer was silently dropped.
    for layer_type, _ in layer_specs:
        if layer_type not in ("tanh", "lstm", "gru"):
            raise ValueError(
                "Unsupported hidden_layer_type %r; expected 'tanh', 'lstm' "
                "or 'gru'." % (layer_type,))

    def _make_stack():
        """Return a new MultiRNNCell built from fresh cell objects."""
        cells = []
        for layer_type, units in layer_specs:
            if layer_type == "tanh":
                cells.append(tf.contrib.rnn.BasicRNNCell(num_units=units))
            elif layer_type == "lstm":
                cells.append(tf.contrib.rnn.BasicLSTMCell(num_units=units))
            else:  # "gru" (guaranteed by the validation above)
                cells.append(GRUCell(num_units=units))
        return MultiRNNCell(cells)

    with tf.variable_scope("encoder"):
        # Build one stack per direction. Passing a single cell instance as
        # both cell_fw and cell_bw either ties the two directions' weights
        # together or fails with a scope-reuse error, depending on the
        # TF 1.x version.
        cell_fw = _make_stack()
        cell_bw = _make_stack()
        enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_fw,
            cell_bw=cell_bw,
            inputs=inputs,
            sequence_length=inputs_sequence_length,
            dtype=tf.float32)
        # enc_output is an (output_fw, output_bw) pair; join them on axis 2.
        enc_output = tf.concat(enc_output, 2)
    return enc_output, enc_state
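# Scraped web-page residue (not code): "Comment list"
# Scraped web-page residue (not code): "Table of contents"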