traintest.py source code

python

Project: LSTMVAE  Author: ashwatthaman
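# Chainer imports are assumed from the calls used below (optimizers.Adam, serializers.save_npz).
from chainer import optimizers, serializers


def train(args, encdec, model_name_base="./{}/model/cvaehidden_kl_{}_{}_l{}.npz"):
    # loadModel is defined on the model class; it presumably restores a saved checkpoint.
    encdec.loadModel(model_name_base, args)
    if args.gpu >= 0:
        # Switch the module-level array backend to CuPy and move the parameters to the GPU.
        import cupy as cp
        global xp
        xp = cp
        encdec.to_gpu()

    optimizer = optimizers.Adam()
    optimizer.setup(encdec)
    # Start from the model's current epoch so an interrupted run can continue.
    for e_i in range(encdec.epoch_now, args.epoch):
        encdec.setEpochNow(e_i)
        loss_sum = 0
        for tupl in encdec.getBatchGen(args):
            # Calling the model on a batch returns the loss variable.
            loss = encdec(tupl)
            loss_sum += loss.data

            # Reset gradients, backpropagate, then apply the Adam update.
            encdec.cleargrads()
            loss.backward()
            optimizer.update()
        print("epoch{}:loss_sum:{}".format(e_i, loss_sum))
        # Save one checkpoint per epoch: data name (twice), epoch index, latent size.
        model_name = model_name_base.format(args.dataname, args.dataname, e_i, args.n_latent)
        serializers.save_npz(model_name, encdec)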
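
The checkpoint name is built by filling the four placeholders of model_name_base with the data name (twice), the epoch index, and the latent size. The sketch below shows that expansion in isolation. The helper name checkpoint_path and the sample values are assumptions made for illustration; they are not part of the LSTMVAE project.

```python
# Default template copied from the train() signature above.
MODEL_NAME_BASE = "./{}/model/cvaehidden_kl_{}_{}_l{}.npz"


def checkpoint_path(dataname, epoch, n_latent, template=MODEL_NAME_BASE):
    """Hypothetical helper: fill the template in the same argument order as train()."""
    return template.format(dataname, dataname, epoch, n_latent)


if __name__ == "__main__":
    # "mydata" and 16 are made-up example values.
    print(checkpoint_path("mydata", 0, 16))
    # prints ./mydata/model/cvaehidden_kl_mydata_0_l16.npz
```

Because the data name also appears as the top-level directory, the ./<dataname>/model/ directory most likely has to exist before the first epoch finishes, otherwise the save step has nowhere to write.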