def gen_batches(data, n_seqs, n_steps):
    """Create a generator that returns batches of size n_seqs x n_steps.

    ``data`` is truncated to a whole number of batches, reshaped into
    ``n_seqs`` rows, and then walked left to right ``n_steps`` columns at a
    time. Trailing elements that do not fill a complete batch are dropped.

    Args:
        data: Array-like of encoded characters, converted with ``np.asarray``
            (presumably 1-D -- TODO confirm against callers).
        n_seqs: Number of sequences (rows) in each batch.
        n_steps: Number of time steps (columns) in each batch.

    Yields:
        ``(x, y)`` pairs of arrays of shape ``(n_seqs, n_steps)``. ``y`` is
        ``x`` shifted one step to the left. Its last column holds the element
        that follows the window in each row. For the final window, which has
        no following column, that element wraps around to the row's first
        element. If ``data`` is too short for a single full batch, nothing is
        yielded.
    """
    data = np.asarray(data)
    characters_per_batch = n_seqs * n_steps
    n_batches = len(data) // characters_per_batch
    if n_batches == 0:
        # Not enough data for even one full batch.
        return
    # Keep only enough characters to make full batches
    data = data[:n_batches * characters_per_batch]
    data = data.reshape([n_seqs, -1])
    width = data.shape[1]
    for n in range(0, width, n_steps):
        x = data[:, n:n + n_steps]
        y = np.zeros_like(x)
        y[:, :-1] = x[:, 1:]
        # The target for the last step is the next column after this window
        # (not x[:, 0], which is the window's own first element). The final
        # window has no next column, so the modulo wraps it to column 0.
        y[:, -1] = data[:, (n + n_steps) % width]
        yield x, y
#-------------------------------------------------------------------------------
# Parse commandline
#-------------------------------------------------------------------------------
评论列表
文章目录