def _build_lstm(self, input_state):
    """Build a single-layer LSTM over ``input_state`` with a feedable initial state.

    The initial state is fed through a float32 placeholder of shape
    ``[None, 2 * self.hidden_state_size]``. It holds the cell state ``c``
    and the hidden state ``h`` concatenated along axis 1, with ``c`` first.

    Args:
        input_state: Tensor whose last dimension must be statically known.
            It is reshaped to ``[batch_size, -1, last_dim]`` (batch-major),
            where ``batch_size`` is the length of ``self.step_size``.
            NOTE(review): assumes the leading dimensions flatten to
            ``batch_size * max_time``; confirm against the caller.

    Returns:
        A tuple ``(lstm_outputs, lstm_state, initial_lstm_state)``:

        - the batch-major per-step outputs of ``tf.nn.dynamic_rnn``;
        - its final ``LSTMStateTuple(c, h)``;
        - the initial-state placeholder to feed.
    """
    # The placeholder name (including its spelling) is kept unchanged so that
    # any code fetching this tensor by name keeps working.
    initial_lstm_state = tf.placeholder(
        tf.float32, [None, 2 * self.hidden_state_size], name='initital_state')
    lstm_cell = BasicLSTMCell(
        self.hidden_state_size, forget_bias=1.0, state_is_tuple=True)
    # With state_is_tuple=True the cell expects an LSTMStateTuple(c, h), not a
    # single concatenated tensor, so split the placeholder into its two halves.
    initial_c, initial_h = tf.split(initial_lstm_state, 2, axis=1)
    initial_tuple_state = tf.nn.rnn_cell.LSTMStateTuple(initial_c, initial_h)
    # self.step_size is passed as dynamic_rnn's sequence_length (one length
    # per batch entry), so its length is the batch size.
    batch_size = tf.shape(self.step_size)[0]
    ox_reshaped = tf.reshape(
        input_state,
        [batch_size, -1, input_state.get_shape().as_list()[-1]])
    lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
        lstm_cell,
        ox_reshaped,
        initial_state=initial_tuple_state,
        sequence_length=self.step_size,
        time_major=False)
    return lstm_outputs, lstm_state, initial_lstm_state
# 评论列表 (scraped web-page residue: "comment list")
# 文章目录 (scraped web-page residue: "table of contents")