rnn_base.py source code

Project: auDeep, Author: auDeep
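Every `tf.split` call in the method below divides its input into equal parts, so the width of a flat initial state vector is fixed by the architecture. Let L be `num_layers`, D the number of directions, S the number of state tensors per cell, and n the width of one state tensor (n is a symbol introduced here for illustration). Then the splits only succeed if:

```latex
\text{state\_size} = L \cdot D \cdot S \cdot n, \qquad
D = \begin{cases} 2 & \text{if bidirectional} \\ 1 & \text{otherwise} \end{cases}, \qquad
S = \begin{cases} 2 & \text{LSTM, for } (c, h) \\ 1 & \text{otherwise} \end{cases}
```

For example, a 2-layer bidirectional LSTM with n = 3 needs state_size = 2 · 2 · 2 · 3 = 24.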
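import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in TensorFlow 2.x
from tensorflow.contrib.rnn import LSTMStateTuple

# CellType is an enum defined elsewhere in the auDeep code base (not shown in this excerpt).


# Method of auDeep's RNN base class; `self` provides initial_state, batch_size,
# state_size, num_layers, bidirectional and cell_type.
def initial_states_tuple(self):
    """
    Create the initial state tensors for the individual RNN cells.

    If no initial state vector was passed to this RNN, all initial states are set to zero. Otherwise, the initial
    state vector is split into a possibly nested tuple of tensors according to the RNN architecture. The return
    value of this function is structured so that it can be passed to the `initial_state` parameter of the RNN
    functions in `tf.contrib.rnn`.

    Returns
    -------
    tuple of tf.Tensor
        A possibly nested tuple of initial state tensors for the RNN cells
    """
    if self.initial_state is None:
        # Zero state covering all layers, directions and (for LSTM) both c and h.
        initial_states = tf.zeros(shape=[self.batch_size, self.state_size], dtype=tf.float32)
    else:
        initial_states = self.initial_state

    # One equally sized chunk per layer.
    initial_states = tuple(tf.split(initial_states, self.num_layers, axis=1))

    if self.bidirectional:
        # Each layer chunk holds the forward state followed by the backward state.
        initial_states = tuple([tf.split(x, 2, axis=1) for x in initial_states])
        # Regroup into one tuple of per-layer forward states and one of backward states.
        initial_states_fw, initial_states_bw = zip(*initial_states)

        if self.cell_type == CellType.LSTM:
            # LSTM cells expect a (c, h) pair, stored as two halves of each chunk.
            initial_states_fw = tuple([LSTMStateTuple(*tf.split(lstm_state, 2, axis=1))
                                       for lstm_state in initial_states_fw])
            initial_states_bw = tuple([LSTMStateTuple(*tf.split(lstm_state, 2, axis=1))
                                       for lstm_state in initial_states_bw])

        initial_states = (initial_states_fw, initial_states_bw)
    else:
        if self.cell_type == CellType.LSTM:
            # LSTM cells expect a (c, h) pair, stored as two halves of each chunk.
            initial_states = tuple([LSTMStateTuple(*tf.split(lstm_state, 2, axis=1))
                                    for lstm_state in initial_states])

    return initial_states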
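To make the nested structure concrete, here is a minimal NumPy sketch that mirrors the splitting logic of `initial_states_tuple` without TensorFlow. The function `split_initial_state` and the `namedtuple` stand-in for `LSTMStateTuple` are hypothetical illustrations, not auDeep code. As the splits in the original imply, the sketch assumes each layer's chunk stores the forward half before the backward half, and each LSTM chunk stores `c` before `h` (the field order of `LSTMStateTuple`).

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for tf.contrib.rnn.LSTMStateTuple: a (c, h) pair.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def split_initial_state(state, num_layers, bidirectional, lstm):
    """Split a flat [batch, state_size] array the way initial_states_tuple does."""
    # One equally sized chunk per layer (np.split raises if not divisible).
    per_layer = tuple(np.split(state, num_layers, axis=1))

    def to_cells(chunks):
        # For LSTM cells, each chunk holds c followed by h.
        if lstm:
            return tuple(LSTMStateTuple(*np.split(x, 2, axis=1)) for x in chunks)
        return tuple(chunks)

    if bidirectional:
        # Each layer chunk holds the forward half followed by the backward half;
        # zip regroups them into per-layer forward and backward tuples.
        fw, bw = zip(*(np.split(x, 2, axis=1) for x in per_layer))
        return to_cells(fw), to_cells(bw)
    return to_cells(per_layer)


# Batch of 4, 2 layers, bidirectional LSTM with 3 units per state tensor:
# state_size = 2 layers * 2 directions * 2 (c, h) * 3 = 24.
state = np.arange(4 * 24, dtype=np.float32).reshape(4, 24)
fw, bw = split_initial_state(state, num_layers=2, bidirectional=True, lstm=True)
print(len(fw), len(bw))                                 # 2 2
print(fw[0].c.shape, fw[0].h.shape)                     # (4, 3) (4, 3)
print(fw[0].c[0].tolist(), bw[1].h[0].tolist())         # [0.0, 1.0, 2.0] [21.0, 22.0, 23.0]
```

The first batch row shows the layout: columns 0 to 11 belong to layer 0 (forward `c`, forward `h`, backward `c`, backward `h`), columns 12 to 23 to layer 1.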