def build_graph(self, x, batch_size=1, n_units=256):
    """Build a single-layer BasicLSTMCell recurrence over ``x``.

    Creates two float32 placeholders for the LSTM (c, h) state so the caller
    can feed the recurrent state between runs, unrolls the cell with
    ``tf.nn.dynamic_rnn`` in batch-major layout (``time_major=False``), and
    records the resulting nodes on ``self``.

    Args:
        x: graph-node wrapper whose ``.node`` is the input tensor. Presumably
            shaped ``(batch, time, features)`` given ``time_major=False`` —
            TODO confirm against callers.
        batch_size: leading dimension of the state placeholders; should match
            the batch dimension of ``x``.
        n_units: number of LSTM units (width of both c and h).

    Returns:
        The raw ``dynamic_rnn`` output tensor (not wrapped in ``graph.TfNode``).

    Attributes set on ``self``:
        phs: the two ``graph.Placeholder`` objects for the (c, h) state.
        ph_state: ``graph.TfNode`` over the two placeholder tensors, with
            ``.checked`` holding the placeholders' ``checked`` tensors.
        zero_state: tuple of two float32 zero arrays of shape
            ``(batch_size, n_units)`` for feeding an initial state.
        state: ``graph.TfNode`` wrapping the final LSTM state tuple.
        weight: ``graph.TfNode`` over the trainable variables whose names
            match the current variable scope name.
    """
    # One placeholder for the cell state c, one for the hidden state h.
    self.phs = [graph.Placeholder(np.float32, [batch_size, n_units]) for _ in range(2)]
    self.ph_state = graph.TfNode(tuple(ph.node for ph in self.phs))
    self.ph_state.checked = tuple(ph.checked for ph in self.phs)
    # Use float32 to match the placeholder dtype; np.zeros defaults to float64.
    self.zero_state = tuple(
        np.zeros([batch_size, n_units], dtype=np.float32) for _ in range(2))

    state = tf.contrib.rnn.LSTMStateTuple(*self.ph_state.checked)
    lstm = tf.contrib.rnn.BasicLSTMCell(n_units, state_is_tuple=True)

    # dynamic_rnn requires sequence_length to be a vector of size [batch].
    # Every sequence in the batch runs the full time dimension, so fill a
    # [batch] vector with the time length. A 1-element slice such as
    # tf.shape(x)[1:2] would only be valid when the batch size is 1.
    x_shape = tf.shape(x.node)
    seq_len = tf.fill(x_shape[:1], x_shape[1])
    outputs, self.state = tf.nn.dynamic_rnn(lstm, x.node, initial_state=state,
                                            sequence_length=seq_len, time_major=False)
    self.state = graph.TfNode(self.state)

    # NOTE(review): get_collection filters variable names with re.match on the
    # scope string. At the root scope (name "") this collects every trainable
    # variable in the graph, and a scope "a" would also match "ab/...".
    # Confirm that callers always build this inside a dedicated variable scope.
    self.weight = graph.TfNode(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                 tf.get_variable_scope().name))
    return outputs
评论列表
文章目录