06-lstm-tensorflow.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:albemarle 作者: SeanTater 项目源码 文件源码
def transform_block(tensor):
    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Permuting batch_size and n_steps
    tensor = tf.transpose(tensor, [1, 0, 2])
    # Reshaping to (n_steps*batch_size, n_input)
    tensor = tf.reshape(tensor, [-1, n_input])
    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    return tf.split(0, n_steps, tensor)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号