def _compute_states(self):
    """Unroll the LSTM recurrence over ``self.length`` steps.

    Reads ``self.inputs``, ``self.initial_states``, ``self.length``,
    ``self.num_hidden_units``, ``self.optional_bias_shift``,
    ``self.activation`` and ``self._linear``.

    NOTE(review): the ``[1, 0, 2]`` transpose implies ``self.inputs`` is
    rank 3 with ``self.length`` entries along axis 1 -- presumably
    ``(batch, time, features)``; confirm against the caller.
    ``self.initial_states`` is squeezed on axis 1 and split in two along
    the following axis, cell state first and hidden state second.

    Returns:
        A tuple ``(outputs, states)``:

        * ``outputs`` -- the per-step hidden states stacked and transposed
          back with ``[1, 0, 2]``; the op is named ``'outputs'``.
        * ``states`` -- the per-step cell states concatenated with
          ``outputs`` along axis 2 (cells first); the op is named
          ``'states'``.
    """
    # Bring axis 1 to the front so that each TensorArray element holds the
    # slice for a single step of the loop.
    step_major_inputs = tf.transpose(self.inputs, [1, 0, 2])
    x_ta = tf.TensorArray(tf.float32, size=self.length).unstack(
        step_major_inputs)
    # Accumulators for the hidden and cell state produced at every step.
    h_ta = tf.TensorArray(tf.float32, size=self.length)
    c_ta = tf.TensorArray(tf.float32, size=self.length)
    num_units = self.num_hidden_units

    def cond(t, c, h, c_acc, h_acc):
        # Only the step counter decides termination; the remaining loop
        # variables are accepted because tf.while_loop passes all of them.
        return tf.less(t, self.length)

    def body(t, c, h, c_acc, h_acc):
        x = x_ta.read(t)
        with tf.variable_scope('lstm'):
            # Candidate cell value followed by the input, forget and
            # output gates, each built by self._linear from (h, x).
            c_tilde = self.activation(
                self._linear(h, x, num_units, scope='c'))
            i = tf.nn.sigmoid(self._linear(h, x, num_units, scope='i'))
            # Only the forget gate forwards `shift` to self._linear
            # (presumably a bias offset -- see _linear for its meaning).
            f = tf.nn.sigmoid(
                self._linear(h, x, num_units,
                             shift=self.optional_bias_shift, scope='f'))
            o = tf.nn.sigmoid(self._linear(h, x, num_units, scope='o'))
            c_new = i * c_tilde + f * c
            h_new = o * self.activation(c_new)
        # TensorArray.write returns a new handle that must be threaded
        # through the loop for the write to take effect.
        c_acc = c_acc.write(t, c_new)
        h_acc = h_acc.write(t, h_new)
        return t + 1, c_new, h_new, c_acc, h_acc

    t = tf.constant(0)
    # The initial state packs the cell state first and the hidden state
    # second along the axis that remains after squeezing axis 1.
    c, h = tf.split(tf.squeeze(self.initial_states, [1]), 2, axis=1)
    _, _, _, c_ta, h_ta = tf.while_loop(cond, body, [t, c, h, c_ta, h_ta])

    # Undo the earlier transpose so the step axis is axis 1 again.
    outputs = tf.transpose(h_ta.stack(), [1, 0, 2], name='outputs')
    cells = tf.transpose(c_ta.stack(), [1, 0, 2])
    states = tf.concat([cells, outputs], axis=2, name='states')
    return outputs, states
评论列表
文章目录