def reset_states(self):
assert self.stateful, 'Layer must be stateful.'
input_shape = self.input_spec[0].shape
if not input_shape[0]:
raise Exception('If a RNN is stateful, a complete ' +
'input_shape must be provided (including batch size).')
if hasattr(self, 'states'):
K.set_value(self.states[0],
np.zeros((input_shape[0], self.output_dim)))
K.set_value(self.states[1],
np.zeros((input_shape[0], self.output_dim)))
else:
self.states = [K.zeros((input_shape[0], self.output_dim)),
K.zeros((input_shape[0], self.output_dim))]
评论列表
文章目录