nse.py source code

python

Project: onto-lstm  Author: pdasigi
def step(self, input_t, states):
    # State layout: [reader_h, reader_c, flattened_mem, flattened_shared_mem, writer_h, writer_c]
    reader_states = states[:2]
    flattened_mem_tm1, flattened_shared_mem_tm1 = states[2:4]
    writer_h_tm1, writer_c_tm1 = states[4:]
    # The memories arrive flattened to 2D; restore them to
    # (batch_size, flattened_size // output_dim, output_dim).
    input_mem_shape = K.shape(flattened_mem_tm1)
    # Floor division keeps the middle dimension an integer under Python 3.
    mem_shape = (input_mem_shape[0], input_mem_shape[1] // self.output_dim, self.output_dim)
    mem_tm1 = K.reshape(flattened_mem_tm1, mem_shape)
    shared_mem_tm1 = K.reshape(flattened_shared_mem_tm1, mem_shape)
    # Run one step of the reader on the current input.
    reader_constants = self.reader.get_constants(input_t)
    reader_states += reader_constants
    o_t, [_, reader_c_t] = self.reader.step(input_t, reader_states)
    # Summarize each memory with respect to the reader output o_t.
    z_t, m_rt = self.summarize_memory(o_t, mem_tm1)
    shared_z_t, shared_m_rt = self.summarize_memory(o_t, shared_mem_tm1)
    # Compose the reader output with both memory summaries.
    c_t = self.compose_memory_and_output([o_t, m_rt, shared_m_rt])
    # Collecting the necessary variables to directly call writer's step function.
    writer_constants = self.writer.get_constants(c_t)  # returns dropouts for W and U (all 1s, see init)
    writer_states = [writer_h_tm1, writer_c_tm1] + writer_constants
    # Making a call to writer's step function, Equation 5
    h_t, [_, writer_c_t] = self.writer.step(c_t, writer_states)  # h_t, writer_c_t: (batch_size, output_dim)
    # Update both memories from the writer output h_t, using z_t and shared_z_t respectively.
    mem_t = self.update_memory(z_t, h_t, mem_tm1)
    shared_mem_t = self.update_memory(shared_z_t, h_t, shared_mem_tm1)
    # Flatten the memories again so they can be carried as states to the next step.
    return h_t, [o_t, reader_c_t, K.batch_flatten(mem_t), K.batch_flatten(shared_mem_t), h_t, writer_c_t]
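
The step function carries each memory between timesteps as a flattened 2D state and restores the 3D shape at the start of the next step. The following is a minimal sketch of that round trip, for illustration only: NumPy stands in for the Keras backend `K`, and the sizes `batch_size`, `num_slots`, and `output_dim` are made-up values, not taken from the onto-lstm project.

```python
import numpy as np

# Hypothetical sizes, chosen only for this illustration.
batch_size, num_slots, output_dim = 2, 3, 4

# A toy 3D memory: (batch_size, num_slots, output_dim).
mem = np.arange(batch_size * num_slots * output_dim, dtype=np.float32)
mem = mem.reshape(batch_size, num_slots, output_dim)

# NumPy stand-in for K.batch_flatten: keep the batch axis, merge the rest.
flattened_mem = mem.reshape(batch_size, -1)

# Recover the 3D shape the same way step() does, using floor division
# so that the middle dimension stays an integer.
flat_shape = flattened_mem.shape
mem_shape = (flat_shape[0], flat_shape[1] // output_dim, output_dim)
restored_mem = flattened_mem.reshape(mem_shape)

# True when flattening and restoring loses no information.
round_trip_ok = bool(np.array_equal(mem, restored_mem))
```

With true division (`/`) the middle entry of `mem_shape` would be a float under Python 3, which NumPy's `reshape` rejects; this is the reason the sketch uses `//`.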