def data_feed():
    """Producer: read training batches and push them onto ``data_q``.

    Reads GPU/reader settings from the module-level ``config``, instantiates
    the configured reader class, then for each epoch shuffles the index list
    and enqueues ``(epoch, idx, batch)`` tuples. Sentinel strings ``'train'``
    (start) and ``'end'`` (done) are enqueued around the data.

    Uses/sets globals:
        max_len    -- largest index that is a whole multiple of the total
                      batch stride (batchsize * num_networks).
        batch_init -- resume offset, advanced by one batch when ``args.A``
                      is 'r' and reset to 0 after the first epoch.
    """
    global max_len
    global batch_init
    # One network per ';'-separated GPU index; clamp to at least 1.
    gpu_indices = [int(x) for x in config.get('gpu', 'index').split(';')]
    # FIX: np.amax(...).astype(np.int) fails on NumPy >= 1.24 (the np.int
    # alias was removed); builtin max() yields the identical integer value.
    num_networks = max(len(gpu_indices), 1)
    DBClass = importlib.import_module('python_utils.datareader.{}'.format(
        config.get('reader', 'data')))
    reader = getattr(DBClass, config.get('reader', 'class'))(config)
    idxs = reader.idxs
    # Truncate so every epoch yields only complete multi-network batches.
    max_len = len(idxs) - (len(idxs) % (batchsize * num_networks))
    data_q.put('train')
    # 'r' = resume mode: skip the batch consumed before the restart.
    if args.A == 'r':
        batch_init += batchsize
    for epoch in six.moves.range(init_epoch, 1 + epochs):
        shuffle(idxs)
        for idx in range(batch_init, max_len, batchsize * num_networks):
            # NOTE(review): the slice covers only `batchsize` items while the
            # loop strides by batchsize*num_networks — confirm read_data()
            # expands the slice per network, otherwise indices are skipped.
            data_batch = reader.read_data(idxs[idx:idx + batchsize],
                                          num_networks)
            data_q.put((epoch, idx, data_batch.copy()))
        batch_init = 0  # resume offset applies to the first epoch only
    data_q.put('end')
    return
# (stray scraped page text, kept as a comment so the module parses)
# 评论列表 ("comment list") / 文章目录 ("article table of contents")