model.py source code

python

Project: merlin    Author: CSTR-Edinburgh
def decoder(self, decoder_inputs, enc_output, enc_states, target_sequence_length):
    """Memory is a tuple containing the forward and backward final states
    (output_states_fw, output_states_bw)."""
    with tf.variable_scope("decoder"):
        # Build one RNN cell per hidden layer, choosing the cell type requested
        # for that layer. GRUCell and MultiRNNCell are presumably imported from
        # tensorflow.contrib.rnn elsewhere in model.py.
        basic_cell = []
        for i in xrange(len(self.hidden_layer_size)):
            if self.hidden_layer_type[i] == "tanh":
                basic_cell.append(tf.contrib.rnn.BasicRNNCell(num_units=self.encoder_layer_size[i]))
            elif self.hidden_layer_type[i] == "lstm":
                basic_cell.append(tf.contrib.rnn.BasicLSTMCell(num_units=self.encoder_layer_size[i]))
            elif self.hidden_layer_type[i] == "gru":
                basic_cell.append(GRUCell(num_units=self.encoder_layer_size[i]))
        multicell = MultiRNNCell(basic_cell)

    if not self.attention:
        # Without attention, seed the bidirectional decoder with the encoder's
        # forward and backward final states.
        dec_output, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=multicell, cell_bw=multicell,
            inputs=decoder_inputs,
            sequence_length=target_sequence_length,
            initial_state_fw=enc_states[0],
            initial_state_bw=enc_states[1])
    else:
        # With attention, wrap the stacked cell in a normalized Bahdanau
        # attention mechanism over the encoder outputs.
        attention_size = decoder_inputs.get_shape().as_list()[-1]
        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            attention_size, enc_output, target_sequence_length,
            normalize=True, probability_fn=tf.nn.softmax)
        cell_with_attention = tf.contrib.seq2seq.AttentionWrapper(
            multicell, attention_mechanism, attention_size)
        dec_output, _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_with_attention, cell_bw=cell_with_attention,
            inputs=decoder_inputs, dtype=tf.float32)
    return dec_output
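
The decoder above expects enc_output and enc_states from a bidirectional encoder; per the docstring, the memory is the (output_states_fw, output_states_bw) tuple returned by tf.nn.bidirectional_dynamic_rnn. The following is a minimal sketch, not taken from the merlin source, of an encoder that would produce such a pair; the function name build_encoder_sketch, the GRU-only cell stack, and the layer sizes are assumptions for illustration.

# Minimal sketch of a matching encoder (assumptions: GRU-only stack,
# illustrative layer sizes; not merlin code).
import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell, MultiRNNCell

def build_encoder_sketch(encoder_inputs, input_sequence_length, layer_sizes=(256, 256)):
    """Return (enc_output, enc_states) in the shape the decoder expects:
    enc_output: [batch, time, 2 * layer_sizes[-1]] concatenated fw/bw outputs,
    enc_states: the (output_states_fw, output_states_bw) tuple."""
    with tf.variable_scope("encoder"):
        cells_fw = MultiRNNCell([GRUCell(n) for n in layer_sizes])
        cells_bw = MultiRNNCell([GRUCell(n) for n in layer_sizes])
        outputs, states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cells_fw,
            cell_bw=cells_bw,
            inputs=encoder_inputs,
            sequence_length=input_sequence_length,
            dtype=tf.float32)
        # outputs is a (fw, bw) pair; concatenate along the feature axis so it
        # can serve as the attention memory (enc_output) in the decoder.
        enc_output = tf.concat(outputs, axis=-1)
    return enc_output, states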