def next_batch(self, batch_size=50):
    """Return the next ``(images, labels)`` mini-batch, reshuffling at epoch end.

    On the very first call (no epoch completed yet and read position 0) the
    requested ``batch_size`` is stored on ``self.batch_size``. It must be a
    positive integer that divides ``self._ndata`` exactly. Otherwise a warning
    is printed and the user is prompted on stdin until a valid value is
    entered. Later calls ignore the ``batch_size`` argument and keep using
    ``self.batch_size``.

    When the read position would move past ``self._ndata``, the epoch counter
    is incremented, ``self.shuffle_index`` is shuffled in place, and the same
    permutation is applied to ``self._images`` and ``self._labels``, so each
    image keeps its label. Reading then restarts at index 0.

    Args:
        batch_size: Requested batch size; only consulted on the first call.

    Returns:
        Tuple ``(images, labels)`` sliced ``[start:end]`` from
        ``self._images`` and ``self._labels``.

    Raises:
        AssertionError: if the batch size exceeds ``self._ndata`` when an
            epoch wraps around (e.g. ``self._ndata == 0``). This is an
            ``assert``, so it is skipped under ``python -O``.
    """
    start = self._index_in_epoch
    if self._epochs_completed == 0 and start == 0:
        self.batch_size = batch_size
        # Python 2's input() eval()s whatever is typed; raw_input() returns
        # the text. Python 3 only has input(), which returns the text.
        try:
            read_line = raw_input  # noqa: F821 (Python 2 only)
        except NameError:
            read_line = input
        # Exact integer check. The previous float test
        # (np.modf(ndata / bs)[0] > 0.0) let negative sizes through, because
        # their fractional part is negative, and divided by zero for 0.
        while self.batch_size <= 0 or self._ndata % self.batch_size != 0:
            print('Warning! Number of data/ batch size must be an integer.')
            print('number of data: %d' % self._ndata)
            print('batch size: %d' % self.batch_size)
            reply = read_line('Input new batch size: ')
            try:
                self.batch_size = int(reply)
            except ValueError:
                # Not a number: leave the (still invalid) size unchanged so
                # the loop warns and prompts again instead of crashing.
                pass
        print('batch size : %d' % self.batch_size)
        print('number of data: %d' % self._ndata)
    self._index_in_epoch += self.batch_size
    if self._index_in_epoch > self._ndata:
        # Number of training epochs completed
        self._epochs_completed += 1
        # Shuffle data: one permutation applied to both arrays keeps every
        # image paired with its label.
        random.shuffle(self.shuffle_index)
        self._images = self._images[self.shuffle_index]
        self._labels = self._labels[self.shuffle_index]
        # Reinitialize counter: the first batch of the new epoch is [0, batch_size).
        start = 0
        self._index_in_epoch = self.batch_size
        assert self.batch_size <= self._ndata
    end = self._index_in_epoch
    return self._images[start:end], self._labels[start:end]
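# 评论列表 ("Comment list": web-page residue from the source article)
# 文章目录 ("Table of contents": web-page residue from the source article)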