def __iter__(self):
    """Yield successive fixed-size batches read from ``self.fname``.

    Each line of the file is split into a label and tab-separated feature
    values (``label<TAB>f1<TAB>f2...``). Lines are fetched with
    :func:`linecache.getline`, whose line numbers are 1-based, starting at
    ``self.index_start``. Iteration stops as soon as fewer than
    ``self.batch_size`` lines remain, so a trailing partial batch is dropped.

    Note: ``self.index_start`` is advanced as batches are produced, so a
    second iteration resumes where the previous one stopped unless the
    caller resets it. ``linecache`` also keeps the file's lines cached in
    memory.

    Yields:
        Batch: built from the names ``['data']`` / ``['softmax_label']`` and
        a one-element list holding an ``mx.nd.array`` for the feature rows
        and for the labels, respectively.
    """
    # These names are identical for every batch, so build them once.
    data_names = ['data']
    label_names = ['softmax_label']
    while True:
        first = self.index_start
        last = first + self.batch_size - 1  # inclusive, 1-based line number
        # Probe the batch's own last line, not the line after it: probing one
        # past the end would skip the final complete batch of the file.
        if not linecache.getline(self.fname, last):
            return
        blabel = []
        bdata = []
        for lineno in range(first, last + 1):
            line = linecache.getline(self.fname, lineno)
            # First field is the label; the rest are the feature columns.
            line_label, line_data = line.strip().split('\t', 1)
            blabel.append(line_label)
            bdata.append(np.array(line_data.split('\t')))
        data_all = [mx.nd.array(bdata)]
        label_all = [mx.nd.array(blabel)]
        self.index_start += self.batch_size
        yield Batch(data_names, data_all, label_names, label_all)
评论列表
文章目录