ntm.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:NTM-Keras 作者: SigmaQuan 项目源码 文件源码
def reset_states(self):
        """Zero out, or lazily create, the recurrent state tensors of this stateful layer.

        The layer keeps six state tensors, each with the batch size as its
        first dimension:

        0. ``h_tm1``, previous inner memory: ``controller_output_dim`` wide
        1. ``c_tm1``, previous inner cell: ``controller_output_dim`` wide
        2. previous memory: a flat ``memory_dim * memory_size`` vector
        3. previous writing addresses: ``num_write_head * memory_size``
        4. previous reading addresses: ``num_read_head * memory_size``
        5. previous reading content: ``num_read_head * memory_dim``

        If ``self.states`` already exists, each of the first six tensors is
        overwritten with zeros via ``K.set_value``. Otherwise a new list of
        zero tensors is created with ``K.zeros``. ``self.depth`` is reset to 0
        once validation has passed.

        Raises:
            AssertionError: if the layer is not stateful.
            ValueError: if the batch size (``input_spec[0].shape[0]``) is not
                known.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        # AssertionError is kept so callers see the same exception type as before.
        if not self.stateful:
            raise AssertionError('Layer must be stateful.')
        input_shape = self.input_spec[0].shape
        batch_size = input_shape[0]
        if not batch_size:
            # ValueError subclasses Exception, so existing ``except Exception``
            # handlers keep working.
            raise ValueError('If a RNN is stateful, a complete '
                             'input_shape must be provided (including batch size).')
        # Only touch layer state after validation has succeeded.
        self.depth = 0

        # Single source of truth for the state layout. Both the "reset" and the
        # "create" branch below use it, so the two paths cannot drift apart.
        state_shapes = [
            (batch_size, self.controller_output_dim),              # h_tm1: previous inner memory
            (batch_size, self.controller_output_dim),              # c_tm1: previous inner cell
            (batch_size, self.memory_dim * self.memory_size),      # previous memory
            (batch_size, self.num_write_head * self.memory_size),  # previous writing addresses
            (batch_size, self.num_read_head * self.memory_size),   # previous reading addresses
            (batch_size, self.num_read_head * self.memory_dim),    # previous reading content
        ]

        if hasattr(self, 'states'):
            # Overwrite the existing backend variables in place. Indexing (rather
            # than zip) keeps the original IndexError if fewer states exist.
            for index, shape in enumerate(state_shapes):
                K.set_value(self.states[index], np.zeros(shape))
        else:
            self.states = [K.zeros(shape) for shape in state_shapes]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号