nse.py source code

python

Project: onto-lstm   Author: pdasigi
```python
def call(self, x, mask=None):
    """Run self.loop over the input and return the tensor selected by self.return_mode."""
    # input_shape = (batch_size, input_length, input_dim). This needs to be defined in build.
    initial_read_states = self.get_initial_states(x, mask)
    # Keras 1.x backend: expand_dims takes `dim` (Keras 2 renamed it to `axis`).
    fake_writer_input = K.expand_dims(initial_read_states[0], dim=1)  # (batch_size, 1, output_dim)
    initial_write_states = self.writer.get_initial_states(fake_writer_input)  # h_0 and c_0 of the writer LSTM
    initial_states = initial_read_states + initial_write_states
    # last_output: (batch_size, output_dim)
    # all_outputs: (batch_size, input_length, output_dim)
    # last_states:
    #       last_memory_state: (batch_size, input_length, output_dim)
    #       last_output
    #       last_writer_ct
    last_output, all_outputs, last_states = self.loop(x, initial_states, mask)
    last_memory = last_states[0]
    if self.return_mode == "last_output":
        return last_output
    elif self.return_mode == "all_outputs":
        return all_outputs
    else:
        # return_mode is "output_and_memory"
        expanded_last_output = K.expand_dims(last_output, dim=1)  # (batch_size, 1, output_dim)
        # Prepend the final output to the memory along the time axis:
        # (batch_size, 1 + input_length, output_dim)
        return K.concatenate([expanded_last_output, last_memory], axis=1)
```
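The `return_mode` branch at the end of `call` can be sketched without Keras. The following is a minimal, framework-free illustration, assuming plain nested Python lists stand in for the `(batch_size, ...)` tensors; `select_output` is a hypothetical helper, not part of onto-lstm. It shows that `"output_and_memory"` prepends the final output to the memory along the time axis, so each batch element ends up with `1 + input_length` vectors.

```python
from typing import List

Vector = List[float]   # one output_dim-sized vector
Matrix = List[Vector]  # a sequence of input_length vectors


def select_output(last_output: List[Vector],
                  all_outputs: List[Matrix],
                  last_memory: List[Matrix],
                  return_mode: str):
    """Hypothetical helper mimicking the return_mode branch of call().

    last_output: batch_size x output_dim
    all_outputs: batch_size x input_length x output_dim
    last_memory: batch_size x input_length x output_dim
    """
    if return_mode == "last_output":
        return last_output
    if return_mode == "all_outputs":
        return all_outputs
    # Any other mode is treated as "output_and_memory", as in the original:
    # wrap each output vector as a length-1 sequence (like expand_dims on axis 1),
    # then join it in front of that example's memory (like concatenate on axis 1).
    return [[out] + mem for out, mem in zip(last_output, last_memory)]


batch_size, input_length, output_dim = 2, 3, 4
last_output = [[float(b)] * output_dim for b in range(batch_size)]
memory = [[[0.5] * output_dim for _ in range(input_length)] for _ in range(batch_size)]
combined = select_output(last_output, memory, memory, "output_and_memory")
print(len(combined), len(combined[0]), len(combined[0][0]))  # batch, 1 + input_length, output_dim
```

Row 0 of each combined sequence is the final output and rows 1 onward are the memory, which is the same layout the Keras version produces with `axis=1`.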