encoders.py source code (Python)

Project: almond-nnparser, author: Stanford-Mobisocial-IoT-Lab
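```python
# TensorFlow 1.x only: tf.contrib was removed in TensorFlow 2.
import tensorflow as tf
from tensorflow.python.util import nest

# Method of the encoder class; self._make_rnn_cell and self._num_layers come from that class.
def encode(self, inputs, input_length, _parses):
    with tf.name_scope('BiLSTMEncoder'):
        # Independent stacks of cells for the forward and backward directions
        fw_cell_enc = tf.contrib.rnn.MultiRNNCell([self._make_rnn_cell(i) for i in range(self._num_layers)])
        bw_cell_enc = tf.contrib.rnn.MultiRNNCell([self._make_rnn_cell(i) for i in range(self._num_layers)])

        # outputs = (fw_outputs, bw_outputs), each [batch, time, units]
        outputs, output_state = tf.nn.bidirectional_dynamic_rnn(fw_cell_enc, bw_cell_enc, inputs,
                                                                sequence_length=input_length,
                                                                dtype=tf.float32)

        fw_output_state, bw_output_state = output_state
        # Concatenate each element of the final state (e.g. c and h of every LSTM layer) along
        # the feature axis, so that the result is compatible with a unidirectional decoder
        output_state = nest.pack_sequence_as(
            fw_output_state,
            [tf.concat((x, y), axis=1)
             for x, y in zip(nest.flatten(fw_output_state), nest.flatten(bw_output_state))])

        # Join both directions on the feature axis: [batch, time, 2 * units]
        return tf.concat(outputs, axis=2), output_state
```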
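The effect of the state-merging step can be shown without TensorFlow. The sketch below is a pure-Python illustration under stated assumptions: nested tuples stand in for the `MultiRNNCell` state, nested lists stand in for tensors, and `LSTMState`, `flatten`, `pack_sequence_as`, `concat_last` and `merge_bidirectional_state` are hypothetical helpers that only mimic `nest.flatten`, `nest.pack_sequence_as` and `tf.concat`; they are not TensorFlow's implementations.

```python
from collections import namedtuple

# Hypothetical stand-in for one layer's LSTM state (c, h); not TensorFlow's class.
LSTMState = namedtuple("LSTMState", ["c", "h"])


def flatten(structure):
    """Return the leaves of a nested tuple structure in depth-first order (mimics nest.flatten)."""
    if isinstance(structure, tuple):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]


def pack_sequence_as(template, leaves):
    """Rebuild the nesting of `template`, filled with `leaves` in order (mimics nest.pack_sequence_as)."""
    remaining = iter(leaves)

    def build(node):
        if isinstance(node, tuple):
            children = [build(child) for child in node]
            # namedtuples are rebuilt with _make so their field names survive
            if hasattr(node, "_fields"):
                return type(node)._make(children)
            return tuple(children)
        return next(remaining)

    return build(template)


def concat_last(x, y):
    """Concatenate two equally shaped nested lists along their last axis (mimics tf.concat on that axis)."""
    if x and isinstance(x[0], list):
        return [concat_last(a, b) for a, b in zip(x, y)]
    return x + y


def merge_bidirectional_state(fw_state, bw_state):
    """Concatenate matching leaves of the forward and backward states, keeping the forward nesting."""
    merged = [concat_last(x, y) for x, y in zip(flatten(fw_state), flatten(bw_state))]
    return pack_sequence_as(fw_state, merged)


# Two layers, batch size 1, 2 units per direction
fw = (LSTMState(c=[[1, 2]], h=[[3, 4]]), LSTMState(c=[[5, 6]], h=[[7, 8]]))
bw = (LSTMState(c=[[9, 10]], h=[[11, 12]]), LSTMState(c=[[13, 14]], h=[[15, 16]]))
state = merge_bidirectional_state(fw, bw)
print(state[0].c)  # [[1, 2, 9, 10]]
print(state[1].h)  # [[7, 8, 15, 16]]

# Outputs per direction are [batch=1, time=2, units=2]; joined they become [1, 2, 4]
print(concat_last([[[1, 2], [3, 4]]], [[[5, 6], [7, 8]]]))  # [[[1, 2, 5, 6], [3, 4, 7, 8]]]
```

Each merged `c` and `h` has 4 units instead of 2 while the per-layer tuple structure is unchanged, so a unidirectional decoder that takes this as its initial state would need cells twice as wide as each encoder direction.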