chainer_train.py source code


Project: jrm_ssl  Author: Fhrozen
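```python
import importlib
from random import shuffle

import six

# Producer for the training loop: reads batches through the configured dataset
# reader and pushes them onto data_q. Relies on module-level globals defined
# elsewhere in chainer_train.py: config, args, batchsize, init_epoch, epochs,
# data_q.
def data_feed():
    global max_len
    global batch_init
    # One network replica per GPU index listed in the config (e.g. '0;1').
    gpu_ids = [int(x) for x in config.get('gpu', 'index').split(';')]
    num_networks = max(len(gpu_ids), 1)  # at least one network
    # Load the dataset reader class named in the config.
    DBClass = importlib.import_module('python_utils.datareader.{}'.format(
        config.get('reader', 'data')))
    reader = getattr(DBClass, config.get('reader', 'class'))(config)
    idxs = reader.idxs
    # Truncate so the sample count is a multiple of the global batch size.
    max_len = len(idxs) - (len(idxs) % (batchsize * num_networks))
    data_q.put('train')  # start marker for the consumer
    if args.A == 'r':  # 'r' mode (presumably resume): skip ahead one batch
        batch_init += batchsize
    for epoch in six.moves.range(init_epoch, 1 + epochs):
        shuffle(idxs)
        for idx in range(batch_init, max_len, batchsize * num_networks):
            data_batch = reader.read_data(idxs[idx:idx + batchsize], num_networks)
            # Queue a copy of the batch together with its epoch and offset.
            data_q.put((epoch, idx, data_batch.copy()))
        batch_init = 0  # only the first epoch after a resume starts mid-way
    data_q.put('end')  # end-of-data sentinel
```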
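The reader lookup in data_feed() builds both a module path and a class name from config values. The following is a minimal sketch of the same pattern. It assumes config is a configparser.ConfigParser, because the two-argument config.get(section, option) calls match that API. The project's python_utils.datareader package is not shown here, so the stdlib collections.OrderedDict stands in as a hypothetical "reader" class. Unlike the real reader, it is constructed without the config argument.

```python
import configparser
import importlib

# Hypothetical config; section and option names mirror those read in data_feed().
CFG_TEXT = """
[gpu]
index = 0;1

[reader]
data = collections
class = OrderedDict
"""

config = configparser.ConfigParser()
config.read_string(CFG_TEXT)

# GPU list: one network replica per ';'-separated index.
gpu_ids = [int(x) for x in config.get('gpu', 'index').split(';')]
num_networks = max(len(gpu_ids), 1)

# Module and class names both come from config and are resolved at runtime.
module = importlib.import_module(config.get('reader', 'data'))
reader_cls = getattr(module, config.get('reader', 'class'))
reader = reader_cls()  # the real reader is called as reader_cls(config)

print(gpu_ids, num_networks, type(reader).__name__)  # [0, 1] 2 OrderedDict
```

In Python 3, ConfigParser does not treat an inline ';' as a comment by default. The value '0;1' therefore reaches split(';') intact.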
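The batch-offset arithmetic is easiest to check in isolation. The helper below is hypothetical; batch_starts is not part of the project. It repeats the max_len truncation and the range(batch_init, max_len, batchsize * num_networks) loop from data_feed(), so you can print the visited offsets for made-up sizes.

```python
def batch_starts(num_samples, batchsize, num_networks, batch_init=0):
    """Return the start offsets that data_feed()'s inner loop would visit.

    The sample count is truncated to a multiple of the global batch
    (batchsize * num_networks), and iteration starts at batch_init.
    """
    step = batchsize * num_networks
    max_len = num_samples - (num_samples % step)
    return list(range(batch_init, max_len, step))


# 103 samples, 10 per network, 2 networks -> global batch of 20, max_len = 100.
print(batch_starts(103, 10, 2))                 # [0, 20, 40, 60, 80]
# Resume case: batch_init advanced by one per-network batch (10).
print(batch_starts(103, 10, 2, batch_init=10))  # [10, 30, 50, 70, 90]
```

With num_networks = 2, each step advances by the global batch of 20 offsets. The slice idxs[idx:idx + batchsize] passed to read_data, however, spans only batchsize (10) indices.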
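data_feed() communicates with the training loop through data_q. It first sends the string 'train', then one (epoch, idx, batch) tuple per step, then 'end'. The sketch below reproduces that message protocol with a thread and a bounded queue.Queue. The excerpt does not show how data_q is created or consumed, so the threading setup, the toy list batches, and the producer/consumer functions are all hypothetical.

```python
import queue
import threading


def producer(data_q, batches, epochs):
    """Simplified stand-in for data_feed() using the same message protocol."""
    data_q.put('train')  # start marker
    for epoch in range(1, epochs + 1):
        for idx, batch in enumerate(batches):
            data_q.put((epoch, idx, list(batch)))  # enqueue a copy
    data_q.put('end')  # end-of-data sentinel


def consumer(data_q):
    """Read messages until 'end'; return (epoch, idx, sum) per batch."""
    if data_q.get() != 'train':
        raise ValueError("expected 'train' start marker")
    received = []
    while True:
        msg = data_q.get()
        if msg == 'end':
            break
        epoch, idx, batch = msg
        received.append((epoch, idx, sum(batch)))  # placeholder for a training step
    return received


q = queue.Queue(maxsize=4)  # bounded: put() blocks when the consumer lags
t = threading.Thread(target=producer, args=(q, [[1, 2], [3, 4]], 2))
t.start()
result = consumer(q)
t.join()
print(result)  # [(1, 0, 3), (1, 1, 7), (2, 0, 3), (2, 1, 7)]
```

Because the queue is bounded, put() blocks whenever the consumer falls behind. This caps how many batches sit in memory when reading runs faster than training.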