def get_state_transfer_rnn(RNN):
    """Convert a Recurrent subclass (e.g. LSTM, GRU) to a state-transferable version.

    A state-transfer RNN can broadcast its hidden state to one or more other
    layers of the same type and compatible dimensions via ``broadcast_state``.

    Parameters
    ----------
    RNN : type
        A Keras ``Recurrent`` subclass (e.g. ``LSTM``, ``GRU``).

    Returns
    -------
    type
        A subclass of ``RNN`` with state-transfer support.
    """
    class StateTransferRNN(RNN):

        def __init__(self, state_input=True, **kwargs):
            # state_outputs: layers that receive this layer's state.
            # state_input: truthy when this layer receives state from another
            # layer (later overwritten with the source layer itself by
            # broadcast_state on the receiving side).
            self.state_outputs = []
            self.state_input = state_input
            super(StateTransferRNN, self).__init__(**kwargs)

        def reset_states(self):
            # Temporarily force `stateful` on so the base class (re)allocates
            # and zeroes the state variables even when this layer is stateful
            # only by virtue of transferring/receiving state.
            stateful = self.stateful
            self.stateful = (stateful or self.state_input
                             or len(self.state_outputs) > 0)
            if self.stateful:
                super(StateTransferRNN, self).reset_states()
            self.stateful = stateful

        def build(self, input_shape):
            # Same stateful-flag trick as reset_states: make the base class
            # build persistent state variables, then restore the user's flag.
            stateful = self.stateful
            self.stateful = (stateful or self.state_input
                             or len(self.state_outputs) > 0)
            super(StateTransferRNN, self).build(input_shape)
            self.stateful = stateful

        def broadcast_state(self, rnns):
            """Wire this layer's post-call states into the given layer(s).

            Accepts a single layer or a list/tuple of layers; layers already
            registered as state outputs are skipped.
            """
            rnns = set(rnns) if isinstance(rnns, (list, tuple)) else {rnns}
            rnns -= set(self.state_outputs)  # skip already-connected layers
            self.state_outputs.extend(rnns)
            for rnn in rnns:
                rnn.state_input = self
                rnn.updates = getattr(rnn, 'updates', [])
                # list(...) is required on Python 3: zip() is a lazy, one-shot
                # iterator, but Keras may traverse `updates` more than once.
                rnn.updates.extend(list(zip(rnn.states,
                                            self.states_to_transfer)))

        def call(self, x, mask=None):
            last_output, outputs, states = K.rnn(
                self.step,
                self.preprocess_input(x),
                self.states or self.get_initial_states(x),
                go_backwards=self.go_backwards,
                mask=mask,
                constants=self.get_constants(x),
                unroll=self.unroll,
                input_length=self.input_spec[0].shape[1])
            # list(...) for Python 3: a bare zip iterator would be exhausted
            # after the first traversal, silently dropping the state updates.
            self.updates = list(zip(self.states, states))
            # Exposed so broadcast_state can wire these into receiving layers.
            self.states_to_transfer = states
            return outputs if self.return_sequences else last_output

    return StateTransferRNN
# NOTE: the two lines below were web-scrape residue from the source blog page,
# not code — "评论列表" = comment list, "文章目录" = article table of contents.