def __init__(self, data, batch_size, big_batch=10):
    """Store a dataset and set up the bookkeeping used to batch it.

    Args:
        data: mapping with keys ``'input'``, ``'speaker_code'`` and
            ``'output'``. Each ``data['input']`` item must support
            ``item + [_EOS_ID]`` (presumably a list of token ids -- confirm
            against callers); a new list with ``_EOS_ID`` appended is
            stored for every item, so the caller's lists are not modified.
        batch_size: number of samples per batch, or the string ``'all'`` to
            treat the whole dataset as one batch (this also forces
            ``big_batch`` to 1).
        big_batch: number of batches grouped into one "big batch"; the
            permutation index spans ``big_batch * batch_size`` samples
            (or fewer if the dataset is smaller).

    Raises:
        ValueError: if ``big_batch * batch_size`` is not positive, e.g.
            ``batch_size == 'all'`` with an empty dataset, or a zero size.
    """
    self.data_inp = [item + [_EOS_ID] for item in data['input']]
    self.data_spc = data['speaker_code']
    self.data_out = data['output']
    self.samples = len(self.data_inp)
    self.batch_size = batch_size
    self.big_batch = big_batch
    self.big_batch_step = 0
    if batch_size == 'all':
        self.batch_size = self.samples
        self.big_batch = 1

    big_batch_samples = self.big_batch * self.batch_size
    if big_batch_samples <= 0:
        raise ValueError(
            'big_batch * batch_size must be positive, got %r'
            % (big_batch_samples,))

    # Index of the last big batch: number of complete big batches minus one,
    # but never below 0, so a dataset that fits in a single (possibly
    # partial) big batch always gets 0. Integer division keeps this value
    # the same on both sides of samples == big_batch_samples and is
    # independent of Python 2/3 division semantics.
    self.max_batch_step = max(self.samples // big_batch_samples, 1) - 1

    self.run_through = 0
    self.perm = None
    # Positions within one big batch; shorter when the dataset is smaller.
    self.perm_index = np.arange(min(big_batch_samples, self.samples))
    self.batch = None
评论列表
文章目录