def __init__(self,
             modelprefix,
             imagepath,
             inputshape,
             labelpath,
             epoch=0,
             format='NCHW',
             *,
             ctx=None):
    """Load an MXNet checkpoint and bind it for single-sample inference.

    Args:
        modelprefix: Checkpoint prefix passed to ``mx.model.load_checkpoint``.
        imagepath: Stored on ``self.imagepath``; not read by this method.
        inputshape: Per-sample input shape with exactly three elements. A
            batch dimension of 1 is prepended to form the bound ``'data'``
            shape. Presumably (C, H, W) when ``format`` is 'NCHW' -- TODO
            confirm against the preprocessing code.
        labelpath: Text file with one label per line; trailing whitespace is
            stripped from each line and the result stored on ``self.labels``.
        epoch: Checkpoint epoch passed to ``mx.model.load_checkpoint``.
        format: Layout string stored on ``self.format``; not used by this
            method. (The name shadows the builtin but is part of the public
            signature, so it is kept.)
        ctx: MXNet context to bind the module on. Defaults to ``mx.gpu()``.

    Raises:
        ValueError: If ``inputshape`` does not have exactly three elements.
    """
    self.modelprefix = modelprefix
    self.imagepath = imagepath
    self.labelpath = labelpath
    self.inputshape = inputshape
    self.epoch = epoch
    self.format = format

    # Fail fast with a clear message instead of an IndexError (too short) or
    # a silently truncated shape (too long) when building the data shape.
    if len(inputshape) != 3:
        raise ValueError(
            'inputshape must have exactly 3 elements, got %r' % (inputshape,))

    with open(labelpath, 'r') as fo:
        self.labels = [line.rstrip() for line in fo]

    sym, arg_params, aux_params = mx.model.load_checkpoint(self.modelprefix, self.epoch)

    if ctx is None:
        ctx = mx.gpu()
    # label_names=None: the module has no label inputs (inference only).
    self.mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    # Bind for a batch of one sample. label_shapes is None because the module
    # was created without label names (this is also what the private
    # ``_label_shapes`` attribute holds before bind).
    self.mod.bind(for_training=False,
                  data_shapes=[('data', (1,) + tuple(inputshape))],
                  label_shapes=None)
    # allow_missing=True tolerates parameters absent from the checkpoint
    # instead of raising. NOTE(review): this can mask a checkpoint that does
    # not match the symbol -- confirm it is intended.
    self.mod.set_params(arg_params, aux_params, allow_missing=True)
评论列表
文章目录