def test(args, encdec, model_name, categ_arr=None, predictFlag=False):
    """Load saved weights into ``encdec``, configure it, and optionally predict.

    Args:
        args: Options object; this function reads ``args.gpu`` and
            ``args.batchsize``.
        encdec: Model object. It receives the loaded weights and is
            configured in place; its ``to_gpu``, ``setBatchSize``,
            ``predict`` and (for "cvae" models) ``categ_vocab.itos`` are used.
        model_name: Path of the saved weights passed to
            ``serializers.load_npz``. If the string contains ``"cvae"``, the
            per-category branch below is taken.
        categ_arr: Iterable of category ids, used only for "cvae" models.
            Defaults to no categories. ``None`` is used instead of a mutable
            ``[]`` default.
        predictFlag: When true, call ``encdec.predict`` (once per category
            for "cvae" models, otherwise once).

    Returns:
        The same ``encdec`` object after loading and configuration.
    """
    # This function rebinds the module-level ``xp`` name when a GPU is requested.
    global xp
    if categ_arr is None:
        categ_arr = []

    # NOTE(review): ``serializers`` is presumably Chainer's serializers
    # module imported elsewhere in this file -- confirm.
    serializers.load_npz(model_name, encdec)

    if args.gpu >= 0:
        import cupy as cp
        # Switch the module-wide array backend to CuPy before moving the model.
        xp = cp
        encdec.to_gpu()

    encdec.setBatchSize(args.batchsize)

    if "cvae" in model_name:
        # Category labels are printed even when predictFlag is false.
        # NOTE(review): with an empty categ_arr, no prediction runs for a
        # "cvae" model even if predictFlag is true -- confirm that is intended.
        for categ in categ_arr:
            print("categ:{}".format(encdec.categ_vocab.itos(categ)))
            if predictFlag:
                encdec.predict(args.batchsize, tag=categ, randFlag=False)
    elif predictFlag:
        encdec.predict(args.batchsize, randFlag=False)
    return encdec
评论列表
文章目录