def __init__(self, config):
    """Build the context-level encoder graph.

    Runs a bottom-up tree LSTM over each sentence, chains the per-sentence
    root states with a bidirectional LSTM, feeds the chain outputs into a
    top-down tree LSTM, and fuses the bottom-up and top-down hidden states
    into the final per-sentence representations.

    Args:
        config: model configuration; must provide ``hidden_dim``.
    """
    # Bottom-up pass: produces one root state per sentence.
    self.c_bp_lstm = context_bottom_up_lstm(config)

    # Treat the whole document as a single batch entry:
    # [sentence_num, hidden_dim] -> [1, sentence_num, hidden_dim].
    self.inputs = tf.expand_dims(self.c_bp_lstm.sentences_root_states, 0)

    # Dynamic sentence count, plus its [1]-shaped sequence-length vector
    # for the (batch size 1) bidirectional RNN below.
    self.sentence_num = tf.gather(tf.shape(self.inputs), 1)
    self.sentence_num_batch = tf.expand_dims(self.sentence_num, 0)

    with tf.variable_scope('context_lstm_forward'):
        self.fwcell = rnn.BasicLSTMCell(config.hidden_dim, activation=tf.nn.tanh)
    with tf.variable_scope('context_lstm_backward'):
        self.bwcell = rnn.BasicLSTMCell(config.hidden_dim, activation=tf.nn.tanh)

    with tf.variable_scope('context_bidirectional_chain_lstm'):
        self._fw_initial_state = self.fwcell.zero_state(1, dtype=tf.float32)
        self._bw_initial_state = self.bwcell.zero_state(1, dtype=tf.float32)
        chain_outputs, chain_state = tf.nn.bidirectional_dynamic_rnn(
            self.fwcell,
            self.bwcell,
            self.inputs,
            self.sentence_num_batch,
            initial_state_fw=self._fw_initial_state,
            initial_state_bw=self._bw_initial_state)

    # Merge the forward/backward outputs, then drop the dummy batch axis:
    # [1, sentence_num, 2*hidden_dim] -> [sentence_num, 2*hidden_dim].
    chain_outputs = tf.concat(chain_outputs, 2)
    chain_outputs = tf.gather(chain_outputs, 0)

    # Top-down pass conditioned on the chain outputs, then fuse the two
    # tree passes into the final sentence states.
    self.c_td_lstm = context_top_down_lstm(config, self.c_bp_lstm, chain_outputs)
    self.sentences_final_states = self.get_tree_states(
        self.c_bp_lstm.sentences_hidden_states,
        self.c_td_lstm.sentences_hidden_states)
# Source: context_encoding.py (Python)
# Scraped page metadata: reads 22, favorites 0, likes 0, comments 0
# (original page sections: comment list, table of contents)