train.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:image-classification-rnn 作者: jiegzhan 项目源码 文件源码
def rnn_model(x, weights, biases):
    """RNN (LSTM or GRU) model for image"""
    x = tf.transpose(x, [1, 0, 2])
    x = tf.reshape(x, [-1, n_input])
    x = tf.split(0, n_steps, x)

    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
    outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
    return tf.matmul(outputs[-1], weights) + biases
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号