def fast_dlstm(self, s_t, state_in, lstm, chunks, h_size):
    """Scan ``lstm`` over ``s_t``, updating one slice of the state per step.

    The state of width ``h_size`` is treated as ``chunks`` equal column
    slices. At every scan step the slice selected by a running step counter
    is passed to ``lstm``; the sub-state it returns is merged back into the
    full state through ``self.conditional_sub_state`` and the counter is
    advanced modulo ``chunks``.

    Args:
        s_t: Input tensor. It is transposed with ``[1, 0, 2]`` so that
            ``tf.scan`` iterates over its second axis (presumably laid out
            as ``(batch, time, features)`` -- TODO confirm against caller).
        state_in: Initial state; unpacked as ``(c, h)`` and read through
            ``.c`` (e.g. an ``LSTMStateTuple``). The ``set_shape`` calls
            below pin every slice to ``[1, h_size // chunks]``, i.e. a
            leading dimension of 1.
        lstm: Callable invoked as ``lstm(inputs, sub_state)`` and expected
            to return ``(output, new_sub_state)``. Presumably a cell whose
            width equals ``h_size // chunks`` -- confirm.
        chunks: Python int, number of slices the state is split into.
        h_size: Python int, full width of the state; must be a multiple of
            ``chunks``.

    Returns:
        ``(rnn_outputs, final_states)``: the first two results of
        ``tf.scan``. Because ``tf.scan`` stacks the value produced at every
        step, ``final_states`` holds the state after each step, not only
        the last one.

    Raises:
        ValueError: If ``chunks`` is not positive or ``h_size`` is not
            divisible by ``chunks``.
    """
    # Fail early with a clear message: chunks == 0 would otherwise surface
    # as a ZeroDivisionError, and an uneven split would silently rebuild a
    # state narrower than the one passed in.
    if chunks < 1:
        raise ValueError(
            "chunks must be a positive integer, got %r" % (chunks,))
    if h_size % chunks != 0:
        raise ValueError(
            "h_size (%r) must be divisible by chunks (%r)" % (h_size, chunks))

    # Width of a single state slice; shared by all helpers below.
    chunk_step_size = h_size // chunks

    def get_sub_state(state, state_step):
        """Slice the columns of ``state`` that belong to ``state_step``."""
        c, h = state
        h_step = state_step * chunk_step_size
        sub_state_h = h[:, h_step: h_step + chunk_step_size]
        sub_state_c = c[:, h_step: h_step + chunk_step_size]
        # The slice bounds depend on a tensor, so the static shape is
        # restored explicitly.
        sub_state_h.set_shape([1, chunk_step_size])
        sub_state_c.set_shape([1, chunk_step_size])
        return tf.contrib.rnn.LSTMStateTuple(sub_state_c, sub_state_h)

    def build_new_state(new_sub_state, previous_state, state_step):
        """Rebuild the full state, merging ``new_sub_state`` into its slot."""
        c_previous_state, h_previous_state = previous_state
        c_new_sub_state, h_new_sub_state = new_sub_state
        h_slices = []
        c_slices = []
        # Length-``chunks`` vector with a 1 at the index of the current step.
        one_hot_state_step = tf.one_hot(state_step, depth=chunks)
        for switch_step in range(chunks):
            h_step = switch_step * chunk_step_size
            is_this_current_step = one_hot_state_step[switch_step]
            # conditional_sub_state is defined elsewhere; presumably it
            # yields the new slice when the flag is 1 and the previous
            # slice otherwise -- confirm against its definition.
            h_s = self.conditional_sub_state(
                is_this_current_step,
                h_new_sub_state,
                h_previous_state[:, h_step: h_step + chunk_step_size])
            h_s.set_shape([1, chunk_step_size])
            c_s = self.conditional_sub_state(
                is_this_current_step,
                c_new_sub_state,
                c_previous_state[:, h_step: h_step + chunk_step_size])
            c_s.set_shape([1, chunk_step_size])
            h_slices.append(h_s)
            c_slices.append(c_s)
        h_new_state = tf.concat(h_slices, axis=1)
        c_new_state = tf.concat(c_slices, axis=1)
        return tf.contrib.rnn.LSTMStateTuple(c_new_state, h_new_state)

    def dlstm_scan_fn(previous_output, current_input):
        """One scan step over the ``(output, state, step)`` accumulator."""
        previous_state = previous_output[1]
        state_step = previous_output[2]
        sub_state = get_sub_state(previous_state, state_step)
        out, sub_state_out = lstm(current_input, sub_state)
        state_out = build_new_state(sub_state_out, previous_state, state_step)
        # Advance to the next slice, wrapping around after the last one.
        new_state_step = tf.mod(state_step + tf.constant(1), chunks)
        return out, state_out, new_state_step

    # Only fills the output slot of the initializer: the step function never
    # reads previous_output[0].
    first_input = state_in.c[:, 0: chunk_step_size]
    rnn_outputs, final_states, _ = tf.scan(
        dlstm_scan_fn,
        tf.transpose(s_t, [1, 0, 2]),
        initializer=(first_input, state_in, tf.constant(0)),
        name="dlstm")
    return rnn_outputs, final_states
# Comment list
# Table of contents