def reset_states(self):
    """Zero the recurrent state of this stateful layer, creating it if needed.

    Six state tensors are maintained, in this order (one row per batch sample):

    0. previous controller output ``h_tm1``: (batch, controller_output_dim)
    1. previous controller cell ``c_tm1``:   (batch, controller_output_dim)
    2. previous memory, flattened:           (batch, memory_dim * memory_size)
    3. previous write addresses:             (batch, num_write_head * memory_size)
    4. previous read addresses:              (batch, num_read_head * memory_size)
    5. previous read contents:               (batch, num_read_head * memory_dim)

    If ``self.states`` already exists, each of its first six entries is
    overwritten in place with zeros via ``K.set_value``. Otherwise
    ``self.states`` is created as a list of fresh ``K.zeros`` variables.
    ``self.depth`` is reset to 0 in either case.

    Raises:
        AssertionError: if the layer is not stateful.
        ValueError: if the layer's input spec does not fix the batch size.
        IndexError: if an existing ``self.states`` has fewer than six entries.
    """
    assert self.stateful, 'Layer must be stateful.'
    input_shape = self.input_spec[0].shape
    self.depth = 0
    batch_size = input_shape[0]
    if not batch_size:
        # ValueError subclasses Exception, so callers catching Exception still work.
        raise ValueError('If a RNN is stateful, a complete ' +
                         'input_shape must be provided (including batch size).')

    # Single source of truth for the state shapes. Both the "reset existing"
    # and "create new" branches use it, so the two branches cannot drift apart.
    state_shapes = [
        (batch_size, self.controller_output_dim),                  # h_tm1
        (batch_size, self.controller_output_dim),                  # c_tm1
        (batch_size, self.memory_dim * self.memory_size),          # memory
        (batch_size, self.num_write_head * self.memory_size),      # write addresses
        (batch_size, self.num_read_head * self.memory_size),       # read addresses
        (batch_size, self.num_read_head * self.memory_dim),        # read contents
    ]

    if hasattr(self, 'states'):
        # Index explicitly, as the original did, so a too-short state list
        # still fails loudly instead of being silently truncated.
        for index, shape in enumerate(state_shapes):
            K.set_value(self.states[index], np.zeros(shape))
    else:
        self.states = [K.zeros(shape) for shape in state_shapes]
评论列表
文章目录