# Assumed imports for this TF 0.x-era snippet: `rnn` here refers to
# tensorflow.python.ops.rnn, which provided dynamic_rnn before TF 1.0.
import tensorflow as tf
from tensorflow.python.ops import rnn

def RNN(inputs, lens, name, reuse):
    print("Building network " + name)
    # Look up one-hot encodings for the integer-encoded inputs
    # (`one_hots` is a module-level constant defined elsewhere in the file).
    inputs = tf.gather(one_hots, inputs)
    # Define the output-projection weights.
    weights = tf.Variable(tf.random_normal([__n_hidden, n_output]), name=name + "_weights")
    biases = tf.Variable(tf.random_normal([n_output]), name=name + "_biases")
    # Run an LSTM cell (or whatever __cell_kind names) over the padded batch.
    outputs, states = rnn.dynamic_rnn(
        __cell_kind(__n_hidden),
        inputs,
        sequence_length=lens,
        dtype=tf.float32,
        scope=name,
        time_major=False)
    assert outputs.get_shape() == (__batch_size, __n_steps, __n_hidden)
    print("Done building network " + name)
    #
    # All these asserts are really documentation: they can't go out of date.
    #
    outputs = tf.expand_dims(outputs, 2)
    assert outputs.get_shape() == (__batch_size, __n_steps, 1, __n_hidden)
    # Tile the projection so the same weights are applied at every timestep.
    tiled_weights = tf.tile(tf.expand_dims(tf.expand_dims(weights, 0), 0),
                            [__batch_size, __n_steps, 1, 1])
    assert tiled_weights.get_shape() == (__batch_size, __n_steps, __n_hidden, n_output)
    # Linear activation, using the RNN inner-loop output for each char.
    # (tf.batch_matmul was folded into tf.matmul in TF 1.0.)
    finals = tf.batch_matmul(outputs, tiled_weights) + biases
    assert finals.get_shape() == (__batch_size, __n_steps, 1, n_output)
    # Squeeze only axis 2: squeezing every size-1 dim would also drop the
    # batch dimension when __batch_size == 1.
    return tf.squeeze(finals, [2])
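# A minimal alternative sketch of the same per-timestep projection (not from
# the original file): instead of tiling `weights` and calling batch_matmul,
# fold time into the batch dimension, do one 2-D matmul, and reshape back.
# This assumes the same __batch_size / __n_steps / __n_hidden / n_output
# globals as RNN() above and is shape-equivalent on those inputs.
def project_per_timestep(outputs, weights, biases):
    # outputs: (__batch_size, __n_steps, __n_hidden)
    flat = tf.reshape(outputs, [-1, __n_hidden])   # (batch * steps, n_hidden)
    projected = tf.matmul(flat, weights) + biases  # (batch * steps, n_output)
    return tf.reshape(projected, [__batch_size, __n_steps, n_output])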
# tf Graph input
#pat_chars = tf.placeholder(tf.float32, [__batch_size, __n_steps, n_input])
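# A hedged usage sketch (assumed, not from the original file): since RNN()
# gathers one-hot rows itself, it expects integer character indices rather
# than the float one-hot placeholder commented out above. Placeholder names
# and dtypes here are illustrative.
pat_chars = tf.placeholder(tf.int32, [__batch_size, __n_steps], name="pat_chars")
pat_lens = tf.placeholder(tf.int32, [__batch_size], name="pat_lens")
pat_logits = RNN(pat_chars, pat_lens, "pat", reuse=False)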
# Source: 10-lstm-tensorflow-char-pat.py