rnnrbm.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:keras_bn_library 作者: bnsnapper 项目源码 文件源码
def reset_states(self):
        assert self.stateful, 'Layer must be stateful.'
        input_shape = self.input_spec[0].shape

        if not input_shape[0]:
            raise Exception('If a RNN is stateful, a complete ' +
                            'input_shape must be provided (including batch size).')

        if hasattr(self, 'states'):
            K.set_value(self.states[0],
                        np.zeros((input_shape[0], self.hidden_recurrent_dim)))
            K.set_value(self.states[1],
                        np.zeros((input_shape[0], self.input_dim)))
            K.set_value(self.states[2],
                        np.zeros((input_shape[0], self.hidden_dim)))
        else:
            self.states = [K.zeros((input_shape[0], self.hidden_recurrent_dim)),
                            K.zeros((input_shape[0], self.input_dim)),
                            K.zeros((input_shape[0], self.hidden_dim))]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号