def apply(self, inputs, update_inputs, reset_inputs, mask=None):
    """Run a GRU-style gated recurrence over a sequence with ``theano.scan``.

    Each step computes, from the previous state ``h``::

        r  = gate_activation(h . state_to_reset  + reset_inputs)
        z  = gate_activation(h . state_to_update + update_inputs)
        h~ = activation((h * r) . state_to_state + inputs)
        h' = z * h~ + (1 - z) * h

    Parameters
    ----------
    inputs, update_inputs, reset_inputs
        Pre-computed input contributions to the candidate state, the update
        gate and the reset gate. ``theano.scan`` iterates over their leading
        axis (presumably time -- confirm). ``inputs.shape[1]`` is passed to
        ``self.initial_state`` (presumably the batch size -- confirm).
    mask
        Optional symbolic tensor iterated alongside the inputs. When given,
        each step's output is ``mask * h' + (1 - mask) * h``, so a 0 entry
        carries the previous state through unchanged. ``None`` disables
        masking.

    Returns
    -------
    The per-step states stacked by ``theano.scan`` (its updates are
    discarded).
    """
    def step(inputs, update_inputs, reset_inputs, states,
             state_to_reset, state_to_update, state_to_state):
        # Scan calls fn with: sequences, then the recurrent state, then the
        # non_sequences in the order listed in the scan call below. The weight
        # parameters must therefore be named in that order and used directly,
        # rather than re-reading self.*, so scan's inner placeholders are used.
        reset_values = self.gate_activation.apply(
            states.dot(state_to_reset) + reset_inputs)
        update_values = self.gate_activation.apply(
            states.dot(state_to_update) + update_inputs)
        next_states_proposed = self.activation.apply(
            (states * reset_values).dot(state_to_state) + inputs)
        return (next_states_proposed * update_values +
                states * (1 - update_values))

    def step_mask(inputs, update_inputs, reset_inputs, mask_input, states,
                  state_to_reset, state_to_update, state_to_state):
        next_states = step(inputs, update_inputs, reset_inputs, states,
                           state_to_reset, state_to_update, state_to_state)
        # Broadcast the step's mask along a new trailing axis. Positions where
        # the mask is 0 keep their previous state.
        return (mask_input[:, None] * next_states +
                (1 - mask_input[:, None]) * states)

    # `mask` is a symbolic variable, so compare against None explicitly.
    # Truth-testing a Theano variable raises TypeError.
    if mask is not None:
        func = step_mask
        sequences = [inputs, update_inputs, reset_inputs, mask]
    else:
        func = step
        sequences = [inputs, update_inputs, reset_inputs]

    states_output, _ = theano.scan(
        fn=func,
        sequences=sequences,
        outputs_info=[self.initial_state('initial_state', inputs.shape[1])],
        non_sequences=[self.state_to_reset, self.state_to_update,
                       self.state_to_state],
        strict=True,
        allow_gc=False)
    return states_output
评论列表
文章目录