def lstmoutput(self, model_input, vocab_size, num_frames):
  number_of_layers = FLAGS.lstm_layers
  lstm_sizes = list(map(int, FLAGS.lstm_cells.split(",")))
  feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
      FLAGS.feature_names, FLAGS.feature_sizes)

  # Split the frame-level input into one tensor per feature group and
  # L2-normalize each group along the feature axis.
  sub_inputs = [tf.nn.l2_normalize(x, dim=2)
                for x in tf.split(model_input, feature_sizes, axis=2)]

  assert len(lstm_sizes) == len(feature_sizes), \
      "length of lstm_sizes (={}) != length of feature_sizes (={})".format(
          len(lstm_sizes), len(feature_sizes))

  outputs = []
  for i in range(len(feature_sizes)):
    with tf.variable_scope("RNN%d" % i):
      sub_input = sub_inputs[i]
      lstm_size = lstm_sizes[i]
      # Build a stacked LSTM for this feature group.
      stacked_lstm = tf.contrib.rnn.MultiRNNCell(
          [
              tf.contrib.rnn.BasicLSTMCell(
                  lstm_size, forget_bias=1.0, state_is_tuple=True)
              for _ in range(number_of_layers)
          ],
          state_is_tuple=True)
      # Run the stacked LSTM over the padded frame sequence.
      output, state = tf.nn.dynamic_rnn(stacked_lstm, sub_input,
                                        sequence_length=num_frames,
                                        swap_memory=FLAGS.rnn_swap_memory,
                                        dtype=tf.float32)
      outputs.append(output)

  # Concatenate the per-feature LSTM outputs along the last axis.
  final_output = tf.concat(outputs, axis=2)
  return final_output
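
For reference, a minimal usage sketch follows. The flag values, tensor shapes, and the class name LstmCnnDeepCombineChainModel are assumptions made for illustration; they are not taken from the original repository.

# Illustrative only: flag values, shapes, and the class name are assumptions.
# Assume FLAGS.feature_names = "rgb,audio", FLAGS.feature_sizes = "1024,128",
# FLAGS.lstm_cells = "1024,128", FLAGS.lstm_layers = 1.
import tensorflow as tf

model_input = tf.placeholder(tf.float32, [None, 300, 1024 + 128])  # [batch, frames, features]
num_frames = tf.placeholder(tf.int32, [None])                      # valid frame count per example

model = LstmCnnDeepCombineChainModel()  # hypothetical class name for this file
lstm_out = model.lstmoutput(model_input, vocab_size=4716, num_frames=num_frames)
# lstm_out: [batch, 300, 1152], i.e. the two per-feature LSTM outputs (1024 + 128
# cells) concatenated along the last axis; vocab_size is not used inside lstmoutput.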
Source file: lstm_cnn_deep_combine_chain_model.py (Python)