eval.py source code

python

Project: mlAlgorithms  Author: gu-yan
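import mxnet as mx


def __init__(self,
             modelprefix,
             imagepath,
             inputshape,
             labelpath,
             epoch=0,
             format='NCHW'):
    # Keep the constructor arguments for use by the rest of the class
    # (imagepath and format are not used inside this constructor).
    self.modelprefix = modelprefix
    self.imagepath = imagepath
    self.labelpath = labelpath
    self.inputshape = inputshape
    self.epoch = epoch
    self.format = format

    # One class name per line; line i is the label for output index i.
    with open(labelpath, 'r') as fo:
        self.labels = [line.rstrip() for line in fo]

    # Loads '<modelprefix>-symbol.json' and '<modelprefix>-<epoch as 4 digits>.params'.
    sym, arg_params, aux_params = mx.model.load_checkpoint(self.modelprefix, self.epoch)
    # Inference-only module with no label inputs; mx.gpu() means GPU 0
    # (use mx.cpu() on a machine without a GPU).
    self.mod = mx.mod.Module(symbol=sym, context=mx.gpu(), label_names=None)
    # Batch size 1; inputshape is the 3-element per-sample shape, e.g. (C, H, W) for 'NCHW'.
    # label_shapes=None because the module has no labels (the original passed the
    # private self.mod._label_shapes, which is None before bind()).
    self.mod.bind(for_training=False,
                  data_shapes=[('data', (1, self.inputshape[0], self.inputshape[1], self.inputshape[2]))],
                  label_shapes=None)
    # allow_missing=True tolerates parameters absent from the checkpoint instead of raising.
    self.mod.set_params(arg_params, aux_params, allow_missing=True)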
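The constructor only loads and binds the network; it does not show how an image is turned into the batch of shape (1, inputshape...) or how outputs are mapped back to self.labels. The following is a minimal sketch under stated assumptions: the image is already resized to the bound height and width, the network's first output is a class-probability vector (for example a softmax layer), and `to_batch`, `top_k`, and `predict` are hypothetical helper names that are not part of the original eval.py. Mean/std normalization is left out because the source does not say how the model was trained.

```python
import numpy as np


def to_batch(image_hwc, layout='NCHW'):
    """Hypothetical helper: turn one H x W x C image into a batch of size 1."""
    img = np.asarray(image_hwc, dtype=np.float32)
    if img.ndim != 3:
        raise ValueError('expected an H x W x C array, got shape %r' % (img.shape,))
    if layout == 'NCHW':
        img = img.transpose(2, 0, 1)  # HWC -> CHW
    elif layout != 'NHWC':
        raise ValueError('unsupported layout: %r' % (layout,))
    return img[np.newaxis, ...]  # prepend the batch axis of size 1


def top_k(probs, labels, k=5):
    """Hypothetical helper: the k (label, probability) pairs with the highest scores."""
    probs = np.asarray(probs).ravel()
    order = np.argsort(probs)[::-1][:k]  # indices sorted by descending score
    return [(labels[i], float(probs[i])) for i in order]


def predict(mod, batch):
    """Hypothetical helper: one forward pass on a bound MXNet Module (needs mxnet)."""
    import mxnet as mx  # imported lazily so the helpers above work without MXNet
    mod.forward(mx.io.DataBatch(data=[mx.nd.array(batch)]), is_train=False)
    return mod.get_outputs()[0].asnumpy()


if __name__ == '__main__':
    batch = to_batch(np.zeros((224, 224, 3)), 'NCHW')
    print(batch.shape)  # (1, 3, 224, 224)
    print(top_k([0.1, 0.7, 0.2], ['cat', 'dog', 'bird'], k=2))  # [('dog', 0.7), ('bird', 0.2)]
```

Inside the class this would be used roughly as `top_k(predict(self.mod, to_batch(img, self.format)), self.labels)`. Note that for `format='NHWC'` the constructor's inputshape would itself have to be given as (H, W, C), since bind() simply prepends the batch size to it.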