ln_lstm2.py source code

python

Project: Multi-channel-speech-extraction-using-DNN    Author: zhr1201
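# Method of a layer-normalized LSTM cell class (class definition not shown).
# Written against the TensorFlow 0.x API: tf.nn.rnn_cell._linear and the
# tf.split(split_dim, num_split, value) argument order.
# `ln` is the layer-normalization helper defined or imported elsewhere in ln_lstm2.py.
def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM) with layer normalization."""
    with tf.variable_scope(scope or type(self).__name__):
        c, h = state

        # Project [inputs, h] onto the four gate pre-activations.
        # bias=False because the shift term of each layer normalization supplies the bias.
        concat = tf.nn.rnn_cell._linear(
            [inputs, h], 4 * self._num_units, False)

        # i: input gate, j: candidate cell input, f: forget gate, o: output gate
        i, j, f, o = tf.split(1, 4, concat)

        # Layer-normalize each gate pre-activation separately.
        i = ln(i, scope='i/')
        j = ln(j, scope='j/')
        f = ln(f, scope='f/')
        o = ln(o, scope='o/')

        new_c = (c * tf.nn.sigmoid(f + self._forget_bias) +
                 tf.nn.sigmoid(i) * self._activation(j))

        # Layer-normalize the new cell state before computing the new hidden state.
        new_h = self._activation(
            ln(new_c, scope='new_h/')) * tf.nn.sigmoid(o)
        new_state = tf.nn.rnn_cell.LSTMStateTuple(new_c, new_h)
        return new_h, new_state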
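The NumPy sketch below traces the same per-step computation for readers who do not have TensorFlow 0.x available. It is a hypothetical re-implementation, not the project's code, and it rests on several assumptions:

- `layer_norm` stands in for the `ln` helper, which the snippet does not show, so its epsilon and parameterization may differ.
- The activation is assumed to be `tanh` and `forget_bias` is assumed to be 1.0. These match the defaults of TensorFlow's `BasicLSTMCell`, but this cell's actual settings are not visible.
- `W` and the `params` dictionary of per-gate (gain, shift) pairs are illustrative names.

```python
import numpy as np


def layer_norm(x, gain, shift, eps=1e-5):
    """Normalize each row of x over its last axis, then scale and shift.

    Hypothetical stand-in for the project's `ln` helper (not shown in the snippet).
    """
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gain * (x - mean) / np.sqrt(var + eps) + shift


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def ln_lstm_step(x, c, h, W, params, forget_bias=1.0):
    """One forward step of a layer-normalized LSTM, mirroring __call__ above.

    x: (batch, input_dim); c, h: (batch, num_units)
    W: (input_dim + num_units, 4 * num_units), no bias term (LN supplies the shift)
    params: hypothetical dict of (gain, shift) pairs keyed 'i', 'j', 'f', 'o', 'new_h'
    forget_bias=1.0 and tanh activation are assumptions, not read from the source.
    """
    concat = np.concatenate([x, h], axis=1) @ W
    i, j, f, o = np.split(concat, 4, axis=1)

    # Layer-normalize each gate pre-activation separately.
    i = layer_norm(i, *params["i"])
    j = layer_norm(j, *params["j"])
    f = layer_norm(f, *params["f"])
    o = layer_norm(o, *params["o"])

    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    # Normalize the cell state before the output nonlinearity.
    new_h = np.tanh(layer_norm(new_c, *params["new_h"])) * sigmoid(o)
    return new_h, new_c


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch, input_dim, num_units = 2, 3, 4
    W = rng.standard_normal((input_dim + num_units, 4 * num_units)) * 0.1
    params = {k: (np.ones(num_units), np.zeros(num_units))
              for k in ("i", "j", "f", "o", "new_h")}
    x = rng.standard_normal((batch, input_dim))
    c0 = np.zeros((batch, num_units))
    h0 = np.zeros((batch, num_units))
    new_h, new_c = ln_lstm_step(x, c0, h0, W, params)
    print(new_h.shape, new_c.shape)  # (2, 4) (2, 4)
```

Unlike the TensorFlow method, this sketch returns `(new_h, new_c)` directly rather than wrapping the state in an `LSTMStateTuple`. Keeping the weight multiplication bias-free mirrors the original's choice to let the layer-normalization shift act as the bias.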