base_controller.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:EAS 作者: han-cai 项目源码 文件源码
def build(self):
        """Build the RNN encoder over the embedded input sequence.

        Embeds ``self.input_seq``, unrolls a static RNN over ``self.num_steps``
        time steps (bidirectional when ``self.bidirectional`` is true), and
        stores the results on ``self.encoder_output`` / ``self.encoder_state``.

        Returns:
            A ``(output, state)`` pair. ``output`` is the per-step RNN output
            stacked and transposed to batch-major layout; ``state`` is the
            final RNN state. In the bidirectional case the forward and
            backward final states are concatenated along axis 1.

        Raises:
            ValueError: in the bidirectional case, if the final forward state
                is neither an ``LSTMStateTuple`` nor a ``tf.Tensor``.
        """
        # NOTE(review): presumably creates self.input_seq / self.seq_len -- confirm
        # against _define_input, which is not visible here.
        self._define_input()

        output = self.input_seq
        output = embedding(output, self.vocab.size, self.embedding_dim, name='layer_embedding')
        input_dim = self.embedding_dim

        # Prepare data shape to match rnn function requirements
        # Current data input shape: [batch_size, num_steps, input_dim]
        # Required shape: 'num_steps' tensors list of shape [batch_size, input_dim]
        output = tf.transpose(output, [1, 0, 2])
        output = tf.reshape(output, [-1, input_dim])
        output = tf.split(output, self.num_steps, 0)

        if self.bidirectional:
            # 'num_steps' tensors list of shape [batch_size, rnn_units * 2]
            fw_cell = build_cell(self.rnn_units, self.cell_type, self.rnn_layers)
            bw_cell = build_cell(self.rnn_units, self.cell_type, self.rnn_layers)
            output, state_fw, state_bw = rnn.static_bidirectional_rnn(
                fw_cell, bw_cell, output, dtype=tf.float32, sequence_length=self.seq_len, scope='encoder')

            # Merge the two directions into a single state of doubled width.
            if isinstance(state_fw, tf.contrib.rnn.LSTMStateTuple):
                encoder_state_c = tf.concat([state_fw.c, state_bw.c], axis=1, name='bidirectional_concat_c')
                encoder_state_h = tf.concat([state_fw.h, state_bw.h], axis=1, name='bidirectional_concat_h')
                state = tf.contrib.rnn.LSTMStateTuple(c=encoder_state_c, h=encoder_state_h)
            elif isinstance(state_fw, tf.Tensor):
                state = tf.concat([state_fw, state_bw], axis=1, name='bidirectional_concat')
            else:
                # Name the offending type so the failure is diagnosable
                # (presumably a nested/tuple state from a multi-layer cell).
                raise ValueError(
                    'Unsupported bidirectional encoder state type: %s '
                    '(expected LSTMStateTuple or tf.Tensor)' % type(state_fw).__name__)
        else:
            # 'num_steps' tensors list of shape [batch_size, rnn_units]
            cell = build_cell(self.rnn_units, self.cell_type, self.rnn_layers)
            output, state = rnn.static_rnn(cell, output, dtype=tf.float32, sequence_length=self.seq_len,
                                           scope='encoder')

        # Last dimension is rnn_units, or rnn_units * 2 when bidirectional.
        output = tf.stack(output, axis=0)  # [num_steps, batch_size, units]
        output = tf.transpose(output, [1, 0, 2])  # [batch_size, num_steps, units]
        self.encoder_output = output
        self.encoder_state = state
        return output, state
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号