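# 15-keras_seq2seq_mod.py 文件源码 (file source code)
#
# python
# 阅读 25 收藏 0 点赞 0 评论 0 (views: 25, bookmarks: 0, likes: 0, comments: 0)
#
# 项目:albemarle 作者: SeanTater 项目源码 文件源码
# (Project: albemarle, author: SeanTater -- project source / file source)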
def get_state_transfer_rnn(RNN):
    '''Convert a Recurrent subclass (e.g. LSTM, GRU) into a state-transferable version.

    A state transfer RNN can push the states produced by its last ``call()``
    into other RNNs of the same type and compatible dimensions, by registering
    ``(target_state, source_state)`` pairs in each target's ``updates``
    (see ``broadcast_state``).

    Args:
        RNN: the recurrent layer class to derive from.

    Returns:
        A new class ``StateTransferRNN`` that subclasses ``RNN``.
    '''

    class StateTransferRNN(RNN):

        def __init__(self, state_input=True, **kwargs):
            # Layers this one pushes its states into (filled by broadcast_state).
            self.state_outputs = []
            # Truthy when this layer receives states: True by default, or the
            # source layer once another layer's broadcast_state() targets it.
            self.state_input = state_input
            super(StateTransferRNN, self).__init__(**kwargs)

        def _needs_states(self):
            '''Return True if state variables must exist for this layer.

            That is the case when the layer is stateful, receives states, or
            sends its states to at least one other layer.
            '''
            return bool(self.stateful or self.state_input or self.state_outputs)

        def reset_states(self):
            '''Reset states, treating the layer as stateful if it transfers states.'''
            stateful = self.stateful
            self.stateful = self._needs_states()
            try:
                if self.stateful:
                    super(StateTransferRNN, self).reset_states()
            finally:
                # Restore the user-visible flag even if the parent raises.
                self.stateful = stateful

        def build(self, input_shape):
            '''Build the layer, forcing state allocation if it transfers states.'''
            stateful = self.stateful
            self.stateful = self._needs_states()
            try:
                super(StateTransferRNN, self).build(input_shape)
            finally:
                # Restore the user-visible flag even if the parent raises.
                self.stateful = stateful

        def broadcast_state(self, rnns):
            '''Wire this layer's last-computed states into one or more RNNs.

            Args:
                rnns: a single target layer, or a list/tuple/set of them.
                    Targets already registered are skipped.

            Raises:
                AttributeError: if there are new targets but ``call()`` has not
                    run yet, so there are no states to transfer.
            '''
            if not isinstance(rnns, (list, tuple, set, frozenset)):
                rnns = [rnns]
            # Deduplicate while keeping the caller's order (deterministic,
            # unlike the set difference previously used).
            new_targets = []
            for rnn in rnns:
                if rnn not in self.state_outputs and rnn not in new_targets:
                    new_targets.append(rnn)
            if not new_targets:
                return
            # Check before mutating anything so a premature call leaves no
            # half-registered targets behind.
            source_states = getattr(self, 'states_to_transfer', None)
            if source_states is None:
                raise AttributeError(
                    'broadcast_state() requires call() to have run first '
                    '(no states_to_transfer yet)')
            self.state_outputs.extend(new_targets)
            for rnn in new_targets:
                rnn.state_input = self
                updates = getattr(rnn, 'updates', [])
                # The target's own call() may have left a non-list here (e.g. a
                # zip iterator on Python 3), which has no .extend().
                if not isinstance(updates, list):
                    updates = list(updates)
                rnn.updates = updates
                rnn.updates.extend(zip(rnn.states, source_states))

        def call(self, x, mask=None):
            '''Run the recurrence and remember the final states for transfer.

            Returns the full output sequence if ``return_sequences`` is set,
            otherwise only the last output.
            '''
            last_output, outputs, states = K.rnn(
                self.step,
                self.preprocess_input(x),
                self.states or self.get_initial_states(x),
                go_backwards=self.go_backwards,
                mask=mask,
                constants=self.get_constants(x),
                unroll=self.unroll,
                input_length=self.input_spec[0].shape[1])
            # Materialize as a list: on Python 3 zip() is a one-shot iterator
            # that is exhausted after the first read and lacks .extend().
            # NOTE(review): this replaces any transfer updates registered by a
            # source layer's broadcast_state() before this call -- confirm the
            # intended call order.
            self.updates = list(zip(self.states, states))
            self.states_to_transfer = states
            return outputs if self.return_sequences else last_output

    return StateTransferRNN
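# --- web page navigation residue below (not part of the code) ---
# 评论列表 (comment list)
# 文章目录 (table of contents)
#
#
# 问题 (questions)
#
#
# 面经 (interview experiences)
#
#
# 文章 (articles)
#
# 微信 (WeChat)
# 公众号 (official account)
#
# 扫码关注公众号 (scan the QR code to follow the official account)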