trainer.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:chainer-cifar 作者: dsanno 项目源码 文件源码
def __init__(self, net, optimizer, epoch_num=100, batch_size=100, device_id=-1, lr_shape='multistep', lr_decay=[0]):
        self.net = net
        self.optimizer = optimizer
        self.epoch_num = epoch_num
        self.batch_size = batch_size
        self.device_id = device_id
        if hasattr(optimizer, 'alpha'):
            self.initial_lr = optimizer.alpha
        else:
            self.initial_lr = optimizer.lr
        self.lr_shape = lr_shape
        self.lr_decay = lr_decay
        if device_id >= 0:
            self.xp = cuda.cupy
            self.net.to_gpu(device_id)
        else:
            self.xp = np
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号