def create_batches(self):
    """Split the train/valid arrays into (input, target) mini-batches.

    Both ``self.train`` and ``self.valid`` (assumed to be 1-D NumPy arrays,
    presumably encoded tokens -- TODO confirm) are truncated to a whole number
    of batches of ``batch_size * seq_length`` elements. Targets are the inputs
    shifted left by one position, with the first input wrapped around to
    become the last target.

    Sets:
        self.num_batches / self.num_valid_batches: batch counts per split.
        self.train / self.valid: the truncated arrays.
        self.x_batches / self.y_batches: lists of ``num_batches`` arrays of
            shape ``(batch_size, seq_length)`` with train inputs / targets.
        self.x_valid / self.y_valid: the same for the validation split.

    Raises:
        ValueError: if either split holds fewer than
            ``batch_size * seq_length`` elements.
    """
    tokens_per_batch = self.batch_size * self.seq_length
    # Integer division is exact for any size, unlike int(a / b) via float.
    self.num_batches = self.train.size // tokens_per_batch
    self.num_valid_batches = self.valid.size // tokens_per_batch

    # When the data is too small, give a clear error instead of the obscure
    # reshape/split/index failure that would follow. Use an explicit raise
    # (not assert) so the check survives `python -O`, and check both splits
    # before mutating any state.
    if self.num_batches == 0:
        raise ValueError(
            f"Not enough training data ({self.train.size} elements) for "
            f"batch_size * seq_length = {tokens_per_batch}. "
            "Make seq_length and batch_size smaller."
        )
    if self.num_valid_batches == 0:
        raise ValueError(
            f"Not enough validation data ({self.valid.size} elements) for "
            f"batch_size * seq_length = {tokens_per_batch}. "
            "Make seq_length and batch_size smaller."
        )

    self.train = self.train[:self.num_batches * tokens_per_batch]
    self.valid = self.valid[:self.num_valid_batches * tokens_per_batch]

    def _inputs_and_targets(data, n_batches):
        # Target at position i is input i + 1; the last target wraps to the
        # first input. np.roll returns a fresh array, like the original copy.
        targets = np.roll(data, -1, axis=0)
        # Lay the stream out as batch_size rows, then cut the columns into
        # n_batches chunks of seq_length each.
        xs = np.split(data.reshape(self.batch_size, -1), n_batches, 1)
        ys = np.split(targets.reshape(self.batch_size, -1), n_batches, 1)
        return xs, ys

    self.x_valid, self.y_valid = _inputs_and_targets(
        self.valid, self.num_valid_batches)
    self.x_batches, self.y_batches = _inputs_and_targets(
        self.train, self.num_batches)
# Leftover web-page navigation labels (not code): "评论列表" (Comment list),
# "文章目录" (Table of contents).