def __iter__(self):
    """Yield one ``Batch`` per fixed-size window of lines from ``self.fname``.

    Each window covers lines ``self.index_start`` through
    ``self.index_start + self.batch_size - 1``, read with ``linecache``.
    ``linecache`` line numbers are 1-based, and it returns ``''`` for a line
    that does not exist (including line 0). NOTE(review): ``index_start``
    presumably starts at 1; confirm in ``__init__``.

    Each line is split on ``'::'`` and must have exactly four fields,
    unpacked as (user, item, rate, <ignored>). The user/item columns become
    the data arrays named ``'user'``/``'item'``, and the rate column becomes
    the label array named ``'rate'``.

    Behaviour notes:
      * Iteration stops when the last line of the next window does not exist,
        so a trailing partial window is never emitted.
      * Lines that do not split into four fields are skipped, so a batch may
        hold fewer than ``batch_size`` rows.
      * ``self.index_start`` is advanced per batch, so a second pass over the
        same object resumes where the first stopped (or yields nothing once
        exhausted) unless it is reset elsewhere.

    Yields:
        ``Batch(data_names, data_all, label_names, label_all)`` for each full
        window.
    """
    while True:
        # The window's last line is index_start + batch_size - 1. Testing one
        # line further would drop a complete final window that ends exactly
        # on the file's last line.
        last_line = self.index_start + self.batch_size - 1
        if not linecache.getline(self.fname, last_line):
            return

        buser, bitem, brate = [], [], []
        for i in range(self.index_start, self.index_start + self.batch_size):
            fields = linecache.getline(self.fname, i).strip().split('::')
            if len(fields) != 4:
                # Blank or malformed line: skipped. NOTE(review): this leaves
                # the batch short of batch_size rows; confirm that consumers
                # (e.g. a module bound to a fixed batch shape) accept that.
                continue
            user, item, rate, _ = fields
            buser.append(user)
            bitem.append(item)
            brate.append(rate)

        # Fields are passed to mx.nd.array as strings. NOTE(review):
        # presumably mx.nd.array's default numeric dtype parses them; confirm.
        data_all = [mx.nd.array(buser), mx.nd.array(bitem)]
        label_all = [mx.nd.array(brate)]
        data_names = ['user', 'item']
        label_names = ['rate']
        self.index_start += self.batch_size
        yield Batch(data_names, data_all, label_names, label_all)
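# 评论列表 ("Comment list"): stray web-page label captured with this code, not part of the program
# 文章目录 ("Table of contents"): stray web-page label captured with this code, not part of the program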