traintest.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:LSTMVAE 作者: ashwatthaman 项目源码 文件源码
def test(args,encdec,model_name,categ_arr=[],predictFlag=False):
    serializers.load_npz(model_name,encdec)
    if args.gpu>=0:
        import cupy as cp
        global xp;xp=cp
        encdec.to_gpu()
    encdec.setBatchSize(args.batchsize)

    if "cvae" in model_name:
        for categ in categ_arr:
            print("categ:{}".format(encdec.categ_vocab.itos(categ)))
            if predictFlag:
                encdec.predict(args.batchsize,tag=categ,randFlag=False)
    elif predictFlag:
        encdec.predict(args.batchsize,randFlag=False)
    return encdec
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号