def build_lstm_inner(H, lstm_input):
    '''
    Build the LSTM decoder, unrolled for H['rnn_len'] time steps.

    Args:
        H: configuration mapping. Keys read here: 'lstm_size',
            'num_lstm_layers', 'batch_size', 'grid_height', 'grid_width'
            and 'rnn_len'.
        lstm_input: tensor fed to the LSTM. The *same* tensor is used as
            the input at every time step; only the recurrent state changes.
            NOTE(review): presumably its leading dimension is
            batch_size * grid_height * grid_width, to match the initial
            state built below -- confirm against the caller.

    Returns:
        A list with H['rnn_len'] entries, the LSTM output of each time
        step in order.
    '''
    def make_cell():
        # A fresh cell object per call so that stacked layers never share
        # one instance (see the MultiRNNCell comment below).
        return rnn_cell.BasicLSTMCell(
            H['lstm_size'], forget_bias=0.0, state_is_tuple=False)

    num_layers = H['num_lstm_layers']
    if num_layers > 1:
        # Do NOT write `[cell] * num_layers`: that repeats a single cell
        # object, which newer TensorFlow 1.x releases either reject or
        # treat as "all layers share the same weights".
        lstm = rnn_cell.MultiRNNCell(
            [make_cell() for _ in range(num_layers)], state_is_tuple=False)
    else:
        lstm = make_cell()

    # One LSTM row per grid cell of every image in the batch.
    batch_size = H['batch_size'] * H['grid_height'] * H['grid_width']
    # state_is_tuple=False, so the state is a single concatenated tensor
    # of width lstm.state_size; start from all zeros.
    state = tf.zeros([batch_size, lstm.state_size])

    outputs = []
    with tf.variable_scope('RNN', initializer=tf.random_uniform_initializer(-0.1, 0.1)):
        for time_step in range(H['rnn_len']):
            if time_step > 0:
                # Variables are created on the first step; every later
                # step must reuse them instead of creating new ones.
                tf.get_variable_scope().reuse_variables()
            output, state = lstm(lstm_input, state)
            outputs.append(output)
    return outputs
评论列表
文章目录