context_encoding.py source code

python

Project: Constituent-Centric-Neural-Architecture-for-Reading-Comprehension  Author: shrshore
# Imports assumed (TensorFlow 1.x); they are not shown in the excerpt.
import tensorflow as tf
from tensorflow.contrib import rnn

def __init__(self, config):
    # Bottom-up tree LSTM over the sentences of the context passage.
    self.c_bp_lstm = context_bottom_up_lstm(config)
    # Root state of every sentence tree, batched as a single sequence.
    self.inputs = self.c_bp_lstm.sentences_root_states
    self.inputs = tf.expand_dims(self.inputs, 0)  # [1, sentence_num, hidden_dim]
    self.sentence_num = tf.gather(tf.shape(self.inputs), 1)  # scalar
    self.sentence_num_batch = tf.expand_dims(self.sentence_num, 0)  # [1]
    with tf.variable_scope('context_lstm_forward'):
        self.fwcell = rnn.BasicLSTMCell(config.hidden_dim, activation=tf.nn.tanh)
    with tf.variable_scope('context_lstm_backward'):
        self.bwcell = rnn.BasicLSTMCell(config.hidden_dim, activation=tf.nn.tanh)
    with tf.variable_scope('context_bidirectional_chain_lstm'):
        # Zero initial states for a batch of one (the whole passage).
        self._fw_initial_state = self.fwcell.zero_state(1, dtype=tf.float32)
        self._bw_initial_state = self.bwcell.zero_state(1, dtype=tf.float32)
        # Chain BiLSTM across the sequence of sentence root states.
        chain_outputs, chain_state = tf.nn.bidirectional_dynamic_rnn(
            self.fwcell, self.bwcell, self.inputs, self.sentence_num_batch,
            initial_state_fw=self._fw_initial_state,
            initial_state_bw=self._bw_initial_state)

    chain_outputs = tf.concat(chain_outputs, 2)  # [1, sentence_num, 2*hidden_dim]
    chain_outputs = tf.gather(chain_outputs, 0)  # [sentence_num, 2*hidden_dim]

    # Top-down tree LSTM, which receives the chain outputs.
    self.c_td_lstm = context_top_down_lstm(config, self.c_bp_lstm, chain_outputs)
    # Pass bottom-up and top-down sentence states to get_tree_states (defined elsewhere).
    self.sentences_final_states = self.get_tree_states(
        self.c_bp_lstm.sentences_hidden_states,
        self.c_td_lstm.sentences_hidden_states)
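For intuition about the chain step, here is a minimal NumPy sketch of what the bidirectional LSTM computes over the sentence root states. It assumes the standard `BasicLSTMCell` update (gates split in the order i, j, f, o, with `forget_bias=1.0` and tanh activation) and zero initial states, and it uses random, untrained weights. The helper names `lstm_cell_params`, `run_lstm`, and `bidirectional_chain` are hypothetical and not part of the project. The sketch drops the batch dimension of size one that `tf.expand_dims` adds and `tf.gather(..., 0)` removes; since that single sequence's length equals `sentence_num`, `sequence_length` causes no masking in the original either.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_cell_params(rng, input_dim, hidden_dim):
    """Hypothetical helper: random weights for one LSTM cell.

    The kernel maps the concatenation [x, h] to the 4 gate pre-activations.
    """
    kernel = rng.normal(scale=0.1, size=(input_dim + hidden_dim, 4 * hidden_dim))
    bias = np.zeros(4 * hidden_dim)
    return kernel, bias


def run_lstm(inputs, kernel, bias, forget_bias=1.0):
    """Run an LSTM over inputs [steps, input_dim]; return outputs [steps, hidden_dim]."""
    hidden_dim = bias.shape[0] // 4
    h = np.zeros(hidden_dim)  # zero initial state, analogous to zero_state(1, ...)
    c = np.zeros(hidden_dim)
    outputs = []
    for x in inputs:
        z = np.concatenate([x, h]) @ kernel + bias
        i, j, f, o = np.split(z, 4)  # assumed gate order: input, candidate, forget, output
        c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
        h = np.tanh(c) * sigmoid(o)
        outputs.append(h)
    return np.stack(outputs)


def bidirectional_chain(sentence_root_states, fw_params, bw_params):
    """[sentence_num, input_dim] -> [sentence_num, 2 * hidden_dim]."""
    fw_out = run_lstm(sentence_root_states, *fw_params)
    # Backward pass: reverse the sentences, run, then reverse the outputs back
    # so that row k still corresponds to sentence k.
    bw_out = run_lstm(sentence_root_states[::-1], *bw_params)[::-1]
    # Same role as tf.concat(chain_outputs, 2) after the batch axis is removed.
    return np.concatenate([fw_out, bw_out], axis=1)


# Usage: 5 sentences whose root states have hidden_dim = 4.
rng = np.random.default_rng(0)
sentence_num, hidden_dim = 5, 4
roots = rng.normal(size=(sentence_num, hidden_dim))
fw_params = lstm_cell_params(rng, hidden_dim, hidden_dim)
bw_params = lstm_cell_params(rng, hidden_dim, hidden_dim)
out = bidirectional_chain(roots, fw_params, bw_params)
print(out.shape)  # (5, 8)
```

Row k of the result pairs the forward state after reading sentences 0..k with the backward state after reading sentences k..n-1 in reverse order, which matches how `bidirectional_dynamic_rnn` realigns its backward outputs with the input order.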