def __init__(self, data_arrays, time_steps, batch_size, nfeatures,
             total_iterations=None):
    """Prepare batched input, target and teacher-forcing arrays.

    Only ``nbatches * batch_size`` examples are kept; trailing examples that
    would not fill a complete batch are dropped.

    Args:
        data_arrays (dict): Must contain the keys ``'inp_txt'``,
            ``'tgt_txt'`` and ``'teacher_tgt'`` whose values support slicing
            and ``reshape`` (e.g. NumPy arrays). ``'inp_txt'`` and
            ``'teacher_tgt'`` are reshaped to
            ``(batch_size, nbatches, time_steps, nfeatures)`` and
            ``'tgt_txt'`` to ``(batch_size, nbatches, time_steps)``.
        time_steps (int): Sequence length of each example.
        batch_size (int): Number of examples per batch.
        nfeatures (int): Features per time step of the input and teacher
            arrays.
        total_iterations (int, optional): Iteration budget stored on the
            instance; defaults to the number of full batches.

    Raises:
        ValueError: If ``data_arrays`` is not a dict, or if there are fewer
            examples than ``batch_size``.
    """
    self.batch_size = batch_size
    self.time_steps = time_steps
    self.nfeatures = nfeatures
    self.index = 0

    if not isinstance(data_arrays, dict):
        raise ValueError("Must provide dict as input")
    # Shallow copy so the reassignments below don't touch the caller's dict.
    self.data_arrays = dict(data_arrays)

    # Number of examples, truncated to an integer multiple of the batch size.
    self.nbatches = len(self.data_arrays['inp_txt']) // self.batch_size
    self.ndata = self.nbatches * self.batch_size
    if self.ndata < self.batch_size:
        raise ValueError('Number of examples is smaller than the batch size')
    self.total_iterations = (self.nbatches if total_iterations is None
                             else total_iterations)

    # Reshape so that axis 0 is the position within a batch and axis 1 is
    # the batch index.
    seq_shape = (self.batch_size, self.nbatches, self.time_steps)
    feat_shape = seq_shape + (self.nfeatures,)
    self.data_arrays['inp_txt'] = \
        self.data_arrays['inp_txt'][:self.ndata].reshape(feat_shape)
    self.data_arrays['tgt_txt'] = \
        self.data_arrays['tgt_txt'][:self.ndata].reshape(seq_shape)
    teacher = self.data_arrays['teacher_tgt'][:self.ndata].reshape(feat_shape)

    # Teacher forcing: shift the decoder inputs one step later in time
    # (axis 2) and use an all-zero start token at the first time step.
    # np.roll returns a new array, so the in-place zeroing below never
    # modifies the caller's data.
    teacher = np.roll(teacher, shift=1, axis=2)
    teacher[:, :, 0, :] = 0
    self.data_arrays['teacher_tgt'] = teacher
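# NOTE(review): two lines of what appears to be scraped web-page UI text
# followed the code here ("评论列表" = "comment list", "文章目录" = "table of contents"); not code.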