def reset_states(self):
assert self.stateful, 'Layer must be stateful.'
input_shape = self.input_spec[0].shape
if not input_shape[0]:
raise Exception('If a RNN is stateful, it needs to know '
'its batch size. Specify the batch size '
'of your input tensors: \n'
'- If using a Sequential model, '
'specify the batch size by passing '
'a `batch_input_shape` '
'argument to your first layer.\n'
'- If using the functional API, specify '
'the time dimension by passing a '
'`batch_shape` argument to your Input layer.')
if hasattr(self, 'states'):
K.set_value(self.states[0],
np.zeros((input_shape[0], self.output_dim)))
else:
self.states = [K.zeros((input_shape[0], self.output_dim))]
layer_normalization_RNN.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录