def __call__(self, X):
    """
    Run the LSTM over a sequence and return the output of every time step.

    Performs the LSTM's forget, input and output operations
    according to: http://arxiv.org/pdf/1402.1128v1.pdf without peepholes

    The recurrence starts from ``self.prev_state`` / ``self.prev_output`` and
    the state and output of the last step are assigned back to those
    variables, so consecutive calls continue the same sequence.

    Parameters:
    ----------
    X: list[Tensor]
        The input list to process by the LSTM, one tensor per time step.
        NOTE(review): presumably every tensor is float32 with shape
        (batch, input_dim) so that it can be multiplied with
        ``self.input_weights`` -- confirm against the caller.

    Returns:
    -------
    list[Tensor]
        The LSTM output of each time step, in the same order as ``X``.
        The list is empty when ``X`` is empty.
    """
    num_steps = len(X)
    outputs = tf.TensorArray(tf.float32, num_steps)
    inputs = tf.TensorArray(tf.float32, num_steps)
    # TensorArray.write returns a new handle, so it has to be rebound.
    for i, step_input in enumerate(X):
        inputs = inputs.write(i, step_input)

    def step_op(time, prev_state, prev_output, inputs_list, outputs_list):
        """Compute one LSTM step and record its output at index ``time``."""
        time_step = inputs_list.read(time)
        # All four gate pre-activations are computed with a single pair of
        # matmuls and then split along the last axis after the reshape.
        gates = (
            tf.matmul(time_step, self.input_weights)
            + tf.matmul(prev_output, self.output_weights)
            + self.bias
        )
        gates = tf.reshape(gates, [-1, self.num_hidden, 4])
        input_gate = tf.sigmoid(gates[:, :, 0])
        forget_gate = tf.sigmoid(gates[:, :, 1])
        candidate_state = tf.tanh(gates[:, :, 2])
        output_gate = tf.sigmoid(gates[:, :, 3])
        state = forget_gate * prev_state + input_gate * candidate_state
        output = output_gate * tf.tanh(state)
        new_outputs = outputs_list.write(time, output)
        return time + 1, state, output, inputs_list, new_outputs

    _, state, output, _, final_outputs = tf.while_loop(
        cond=lambda time, *_: time < num_steps,
        body=step_op,
        loop_vars=(
            tf.constant(0, dtype=tf.int32),
            self.prev_state,
            self.prev_output,
            inputs,
            outputs,
        ),
        parallel_iterations=32,
        swap_memory=True,
    )

    state_updates = [
        self.prev_state.assign(state),
        self.prev_output.assign(output),
    ]
    # In graph mode an assign op only runs if something that is fetched
    # depends on it. Creating the read ops inside this context makes every
    # returned tensor depend on the state updates; when executing eagerly the
    # assigns above have already run and the context is a no-op.
    with tf.control_dependencies(state_updates):
        return [final_outputs.read(step) for step in range(num_steps)]
# Web-page residue, not part of the code: "Comment list" / "Table of contents"