def sub_lstm(self, model_input, num_frames, lstm_size, number_of_layers, sub_scope=""):
    # Assumes `import tensorflow as tf` (TF 1.x) and a module-level FLAGS object
    # defined elsewhere in the project.
    # Stack `number_of_layers` LSTM cells into a single multi-layer RNN cell.
    stacked_lstm = tf.contrib.rnn.MultiRNNCell(
        [
            tf.contrib.rnn.BasicLSTMCell(
                lstm_size, forget_bias=1.0, state_is_tuple=True)
            for _ in range(number_of_layers)
        ],
        state_is_tuple=True)
    loss = 0.0  # not used by this helper
    with tf.variable_scope(sub_scope + "-RNN"):
        # Unroll the stacked LSTM over the frame sequence; `num_frames` masks padded frames.
        outputs, state = tf.nn.dynamic_rnn(stacked_lstm, model_input,
                                           sequence_length=num_frames,
                                           swap_memory=FLAGS.rnn_swap_memory,
                                           dtype=tf.float32)
    # `state` is a tuple of LSTMStateTuple(c, h), one per layer; concatenate the cell states.
    # A list comprehension is used instead of map() so tf.concat receives a list under Python 3.
    final_state = tf.concat([s.c for s in state], axis=1)
    return final_state
Source file: distillchain_lstm_memory_deep_combine_chain_model.py
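As a rough illustration of the pattern used above (stacked BasicLSTMCells run through tf.nn.dynamic_rnn, then the per-layer cell states concatenated), here is a minimal stand-alone TF 1.x sketch. The batch size, feature dimension, scope name, and the hard-coded swap_memory=True are assumptions for the example, not values taken from the original file.

import numpy as np
import tensorflow as tf

# Illustrative sizes only.
batch_size, max_frames, feature_size, lstm_size, num_layers = 4, 20, 1024, 512, 2

model_input = tf.placeholder(tf.float32, [None, max_frames, feature_size])
num_frames = tf.placeholder(tf.int32, [None])

stacked_lstm = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)
     for _ in range(num_layers)],
    state_is_tuple=True)

with tf.variable_scope("demo-RNN"):
    outputs, state = tf.nn.dynamic_rnn(stacked_lstm, model_input,
                                       sequence_length=num_frames,
                                       swap_memory=True,
                                       dtype=tf.float32)

# Same reduction as sub_lstm: keep only the cell state `c` of each layer.
final_state = tf.concat([s.c for s in state], axis=1)  # [batch, num_layers * lstm_size]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feat = np.random.rand(batch_size, max_frames, feature_size).astype(np.float32)
    lengths = np.random.randint(1, max_frames + 1, size=batch_size)
    print(sess.run(final_state, {model_input: feat, num_frames: lengths}).shape)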