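# 10-lstm-tensorflow-char-pat.py — source file
#
# Language: Python
# Page stats: 28 reads · 0 favorites · 0 likes · 0 comments
#
# Project: albemarle · Author: SeanTater (project source / file source)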
def RNN(inputs, lens, name, reuse):
    """Build a recurrent network with a per-time-step linear output layer.

    The input indices are embedded by gathering rows of the module-level
    ``one_hots`` table. The result runs through ``rnn.dynamic_rnn`` using a
    cell built by ``__cell_kind(__n_hidden)``. One weight matrix and bias,
    shared across all time steps, then project every RNN output to
    ``n_output`` values.

    Args:
        inputs: Indices into ``one_hots``; presumably shaped
            (__batch_size, __n_steps) -- TODO confirm against the caller.
        lens: Per-example sequence lengths, forwarded to ``dynamic_rnn`` as
            ``sequence_length``.
        name: Prefix for the created variables (``<name>_weights``,
            ``<name>_biases``) and the ``dynamic_rnn`` variable scope.
        reuse: Accepted but currently unused. NOTE(review): new
            ``tf.Variable``s are created on every call regardless of this flag.

    Returns:
        A float tensor of shape (__batch_size, __n_steps, n_output) with the
        linear activations for every time step.
    """
    print("Building network " + name)
    # Replace each input index by its row of the one-hot table.
    inputs = tf.gather(one_hots, inputs)

    # Output projection, shared by every time step.
    weights = tf.Variable(tf.random_normal([__n_hidden, n_output]), name=name+"_weights")
    biases = tf.Variable(tf.random_normal([n_output]), name=name+"_biases")

    # Run the recurrent cell over the batch (batch-major layout).
    outputs, _ = rnn.dynamic_rnn(
        __cell_kind(__n_hidden),
        inputs,
        sequence_length=lens,
        dtype=tf.float32,
        scope=name,
        time_major=False)
    assert outputs.get_shape() == (__batch_size, __n_steps, __n_hidden)
    print("Done building network " + name)

    #
    # All these asserts are actually documentation: they can't be out of date
    #

    # Merge the batch and time axes so one rank-2 matmul applies the same
    # projection to every step. This needs no tiled copy of `weights` per
    # (batch, step). Rank-2 tf.matmul exists in every TF release, whereas
    # tf.batch_matmul was removed in TF 1.0.
    flat_outputs = tf.reshape(outputs, [__batch_size * __n_steps, __n_hidden])
    assert flat_outputs.get_shape() == (__batch_size * __n_steps, __n_hidden)

    # Linear activation, using rnn inner loop output for each char
    flat_finals = tf.matmul(flat_outputs, weights) + biases
    assert flat_finals.get_shape() == (__batch_size * __n_steps, n_output)

    # Restore the (batch, steps) axes with an explicit reshape. An axis-less
    # tf.squeeze would also drop the batch/step/output axes whenever one of
    # them has size 1 (e.g. a batch of one).
    finals = tf.reshape(flat_finals, [__batch_size, __n_steps, n_output])
    assert finals.get_shape() == (__batch_size, __n_steps, n_output)
    return finals

# tf Graph input
#pat_chars = tf.placeholder(tf.float32, [__batch_size, __n_steps, n_input])
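# --- Scraped web-page navigation below (not part of the program) ---
# Comment list
# Table of contents
#
# Questions
#
# Interview experiences
#
# Articles
#
# WeChat
# Official account
#
# Scan the QR code to follow the official account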