def transform_block(tensor):
    """Split a batch of sequences into a per-time-step list for `rnn`.

    The `rnn` function requires its input as a list of ``n_steps`` tensors,
    each of shape ``(batch_size, n_input)``, instead of one
    ``(batch_size, n_steps, n_input)`` tensor.

    Args:
        tensor: Input expected to have shape ``(batch_size, n_steps, n_input)``.

    Returns:
        A list of ``n_steps`` tensors; element ``t`` is ``tensor[:, t, :]``
        with shape ``(batch_size, n_input)``.

    Note:
        Reads the free name ``n_steps``, which must be defined outside this
        function.
    """
    # The previous transpose -> reshape -> tf.split(0, n_steps, x) chain used
    # the pre-1.0 ``tf.split(split_dim, num_split, value)`` argument order;
    # TF >= 1.0 interprets the same call as ``tf.split(value,
    # num_or_size_splits, axis)`` and fails. Unstacking along the time axis
    # (axis 1) produces the same slices directly, and passing
    # ``num=n_steps`` turns a time-dimension mismatch into an error instead
    # of a silent mis-slice.
    return tf.unstack(tensor, num=n_steps, axis=1)
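# 评论列表 (comment list) -- leftover web-page navigation label
# 文章目录 (table of contents) -- leftover web-page navigation label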