network.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:meta-learning 作者: ioanachelu 项目源码 文件源码
def fast_dlstm(self, s_t, state_in, lstm, chunks, h_size):
        """Run a chunked LSTM ("dilated", judging by the name) over a sequence via tf.scan.

        The full LSTM state of width ``h_size`` is split into ``chunks`` equal
        slices of width ``h_size // chunks``. At scan step ``t`` only slice
        ``t % chunks`` is cut out of the full state, fed to ``lstm`` together
        with that step's input, and written back; the remaining slices are
        routed through ``self.conditional_sub_state`` (defined elsewhere;
        presumably it keeps their previous values — verify).

        Args:
            s_t: Rank-3 input tensor. It is transposed with perm ``[1, 0, 2]``
                and scanned along the resulting leading axis, so axis 1 of
                ``s_t`` is the step axis (presumably (batch, time, features) —
                TODO confirm against callers).
            state_in: ``LSTMStateTuple(c, h)`` for the full state; ``c`` and
                ``h`` are sliced along axis 1.
            lstm: Callable ``lstm(input, LSTMStateTuple) -> (out, LSTMStateTuple)``
                applied to one slice (presumably an RNN cell with
                ``h_size // chunks`` units — verify).
            chunks: Number of state slices (a Python int; used with ``range``).
            h_size: Width of the full cell/hidden state (a Python int).

        Returns:
            ``(rnn_outputs, final_states)`` as stacked by ``tf.scan``:
            ``rnn_outputs`` holds the ``lstm`` output of every step, and
            ``final_states`` holds the full state after *every* step (stacked
            along a new leading axis), not only the last one.

        Raises:
            ValueError: if ``chunks`` or ``h_size`` is not positive, or if
                ``h_size`` is not a multiple of ``chunks``. Otherwise the
                rebuilt state would be narrower than ``state_in`` and the scan
                would most likely fail later with a much less clear shape error.
        """
        if chunks <= 0 or h_size <= 0 or h_size % chunks:
            raise ValueError(
                "h_size (%r) must be a positive multiple of chunks (%r)" % (h_size, chunks))
        # Width of one state slice; shared by all helpers below.
        chunk_step_size = h_size // chunks

        def get_sub_state(state, state_step):
            # Cut slice number `state_step` (a scalar int tensor) out of c and h.
            c, h = state
            h_step = state_step * chunk_step_size
            sub_state_h = h[:, h_step: h_step + chunk_step_size]
            sub_state_c = c[:, h_step: h_step + chunk_step_size]
            # Pin the static shape (presumably lost because the slice offset is
            # a tensor). NOTE(review): hard-codes a batch size of 1.
            sub_state_h.set_shape([1, chunk_step_size])
            sub_state_c.set_shape([1, chunk_step_size])
            return tf.contrib.rnn.LSTMStateTuple(sub_state_c, sub_state_h)

        def build_new_state(new_sub_state, previous_state, state_step):
            # Reassemble the full state slice by slice. For each slot, the
            # one-hot flag marks whether it is the slot updated this step;
            # self.conditional_sub_state picks between the new sub-state and the
            # previous slice (its exact semantics are defined outside this view).
            c_previous_state, h_previous_state = previous_state
            c_new_sub_state, h_new_sub_state = new_sub_state
            h_slices = []
            c_slices = []
            one_hot_state_step = tf.one_hot(state_step, depth=chunks)

            for switch_step in range(chunks):
                h_step = switch_step * chunk_step_size
                is_this_current_step = one_hot_state_step[switch_step]
                h_s = self.conditional_sub_state(is_this_current_step, h_new_sub_state,
                                                 h_previous_state[:, h_step: h_step + chunk_step_size])
                h_s.set_shape([1, chunk_step_size])
                c_s = self.conditional_sub_state(is_this_current_step,
                                                 c_new_sub_state,
                                                 c_previous_state[:, h_step: h_step + chunk_step_size])
                c_s.set_shape([1, chunk_step_size])
                h_slices.append(h_s)
                c_slices.append(c_s)
            h_new_state = tf.concat(h_slices, axis=1)
            c_new_state = tf.concat(c_slices, axis=1)
            return tf.contrib.rnn.LSTMStateTuple(c_new_state, h_new_state)

        def dlstm_scan_fn(previous_output, current_input):
            # previous_output is the (out, state, step) triple from the previous
            # scan iteration; the previous `out` is never read.
            _, previous_state, state_step = previous_output

            sub_state = get_sub_state(previous_state, state_step)
            out, sub_state_out = lstm(current_input, sub_state)
            state_out = build_new_state(sub_state_out, previous_state, state_step)
            # Advance to the next slice, wrapping around after the last one.
            new_state_step = tf.mod(state_step + tf.constant(1), chunks)

            return out, state_out, new_state_step

        # tf.scan needs an initial value for the `out` slot with the same
        # structure as the lstm output; it is never read, so the first slice of
        # c is used as a placeholder of matching shape.
        first_input = state_in.c[:, 0: chunk_step_size]
        rnn_outputs, final_states, _ = tf.scan(
            dlstm_scan_fn,
            tf.transpose(s_t, [1, 0, 2]),  # scan along axis 1 of s_t
            initializer=(first_input, state_in, tf.constant(0)),
            name="dlstm")

        return rnn_outputs, final_states
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号