model.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:merlin 作者: CSTR-Edinburgh 项目源码 文件源码
def encoder(self,inputs,inputs_sequence_length):
           with tf.variable_scope("encoder"):
                basic_cell=[]
                for i in xrange(len(self.hidden_layer_size)):
                    if self.hidden_layer_type[i]=="tanh":
                        basic_cell.append(tf.contrib.rnn.BasicRNNCell(num_units=self.encoder_layer_size[i]))
                    if self.hidden_layer_type[i]=="lstm":
                        basic_cell.append(tf.contrib.rnn.BasicLSTMCell(num_units=self.encoder_layer_size[i]))
                    if self.hidden_layer_type[i]=="gru":
                         basic_cell.append(GRUCell(num_units=self.encoder_layer_size[i]))
                multicell=MultiRNNCell(basic_cell)
                enc_output, enc_state=tf.nn.bidirectional_dynamic_rnn(cell_fw=multicell,cell_bw=multicell,inputs=inputs,\
                             sequence_length=inputs_sequence_length,dtype=tf.float32)
                enc_output=tf.concat(enc_output,2)
                #enc_state=(tf.concat(enc_state[0])
                return enc_output, enc_state
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号