def reset_states(self):
    """Zero the two recurrent state tensors of a stateful conv-recurrent layer.

    The state shape is the batch size taken from ``self.input_shape``
    followed by the per-timestep output dimensions of ``self.output_shape``
    (the time axis at index 1 is dropped when ``return_sequences`` is set).
    If usable state variables already exist on ``self.states`` they are
    overwritten in place with zeros via ``K.set_value``; otherwise two new
    zero tensors are created with ``K.zeros`` and stored on ``self.states``.

    Raises:
        AssertionError: if the layer is not stateful. Raised explicitly
            rather than via ``assert`` so the check survives ``python -O``.
        ValueError: if the batch dimension of ``self.input_shape`` is not
            known. ``ValueError`` subclasses ``Exception``, so callers that
            caught the previous generic ``Exception`` keep working.
    """
    if not self.stateful:
        # Explicit raise: a bare ``assert`` is stripped under ``python -O``,
        # which would let a non-stateful layer silently allocate states.
        raise AssertionError('Layer must be stateful.')

    input_shape = self.input_shape
    batch_size = input_shape[0]
    if not batch_size:
        raise ValueError('If a RNN is stateful, a complete '
                         'input_shape must be provided '
                         '(including batch size). '
                         'Got input shape: ' + str(input_shape))

    # With return_sequences the output carries a time axis at index 1 that
    # the per-step state does not have, so skip it. Slicing (instead of
    # unpacking into row/col/filter) keeps the output's own dim order and
    # does not hard-code the number of per-step dimensions.
    if self.return_sequences:
        step_shape = tuple(self.output_shape[2:])
    else:
        step_shape = tuple(self.output_shape[1:])
    state_shape = (batch_size,) + step_shape

    states = getattr(self, 'states', None)
    # hasattr() alone is not enough: ``states`` may exist but hold None
    # placeholders (presumably set when the layer is built non-stateful --
    # NOTE(review): confirm against build()), and K.set_value cannot write
    # into None. Treat that case like "no states yet".
    if states and all(state is not None for state in states):
        K.set_value(states[0], np.zeros(state_shape))
        K.set_value(states[1], np.zeros(state_shape))
    else:
        self.states = [K.zeros(state_shape), K.zeros(state_shape)]
recurrent_convolutional.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录