test_dlstm.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:meta-learning 作者: ioanachelu 项目源码 文件源码
def fast_dlstm(s_t, state_in):
    """Unroll a 256-unit LSTM cell over ``s_t`` with ``tf.scan``, gating its state per chunk.

    After every step the first component of the cell's new state is cut
    into 8 equal chunks; each chunk is passed through
    ``conditional_backprop`` together with one entry of a one-hot vector
    whose hot position advances by one (modulo 8) per step.

    Args:
        s_t: Input tensor; it is transposed with ``[1, 0, 2]`` before the
            scan (presumably batch-major ``(batch, time, features)`` --
            TODO confirm against caller).
        state_in: Two-element sequence unpacked into
            ``rnn.LSTMStateTuple``; its second element also seeds the
            "output" slot of the scan accumulator.

    Returns:
        A tuple ``(out_states, cell_states, state_out)``:
        ``out_states`` is the second state component at every step,
        ``cell_states`` is the first state component at every step, and
        ``state_out`` is a two-element list holding the last-step values of
        the first and second components. All three take index 0 along the
        second axis (presumably the batch axis with batch size 1 -- TODO
        confirm).
    """
    hidden_units = 256
    chunks = 8

    def dilate_one_time_step(state_vec, gates, n_chunks):
        # Gate each contiguous slice of the state vector with its own switch.
        width = hidden_units // n_chunks
        gated = [
            conditional_backprop(gates[idx], state_vec[start:start + width])
            for idx, start in zip(range(n_chunks),
                                  range(0, hidden_units, width))
        ]
        # Stack the slices and flatten them back into rows of hidden_units.
        return tf.reshape(tf.stack(gated), [-1, hidden_units])

    lstm = rnn.LSTMCell(hidden_units, state_is_tuple=True)

    def dlstm_scan_fn(carry, x_t):
        # carry mirrors the initializer: (output, LSTM state tuple, step index).
        _, prev_state, step = carry
        out, new_state = lstm(x_t, prev_state)
        one_hot_step = tf.one_hot(step, depth=chunks)
        # NOTE(review): the first state component is what gets gated here; if
        # rnn.LSTMStateTuple is TensorFlow's (c, h) tuple this is the cell
        # state rather than the hidden output -- confirm this is intended.
        gated_first = dilate_one_time_step(
            tf.squeeze(new_state[0]), one_hot_step, chunks)
        gated_state = rnn.LSTMStateTuple(gated_first, new_state[1])
        next_step = tf.mod(step + tf.constant(1), chunks)
        return out, gated_state, next_step

    time_major_inputs = tf.transpose(s_t, [1, 0, 2])
    scan_init = (state_in[1], rnn.LSTMStateTuple(*state_in), tf.constant(0))
    # Only the accumulated states are used; per-step outputs and step indices
    # returned by the scan are discarded.
    _, final_states, _ = tf.scan(dlstm_scan_fn,
                                 time_major_inputs,
                                 initializer=scan_init)

    first_seq = final_states[0]
    second_seq = final_states[1]
    state_out = [first_seq[-1, 0, :], second_seq[-1, 0, :]]
    cell_states = first_seq[:, 0, :]
    out_states = second_seq[:, 0, :]
    return out_states, cell_states, state_out
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号