segment.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def __init__(self, model_source="model", cuda=False):
        """Restore a trained ``BiLSTM_Cut`` model from a checkpoint, ready for inference.

        Args:
            model_source: Path handed to ``torch.load``. The loaded object is
                indexed with the keys "src_dict", "trains_score", "settings"
                and "model".
            cuda: If True, load without remapping and move the model to the
                GPU. Otherwise every stored tensor is remapped to CPU memory
                and the model is moved to the CPU.
        """
        self.cuda = cuda
        # Keep the torch.cuda namespace when running on GPU, plain torch otherwise.
        # Presumably used elsewhere to build tensors on the matching device; verify.
        self.torch = torch.cuda if cuda else torch

        # On the CPU path, the map_location callable returns each deserialized
        # storage unchanged, so it stays in CPU memory even if it was saved from a GPU.
        load_kwargs = {} if self.cuda else {"map_location": lambda storage, loc: storage}
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(model_source, **load_kwargs)

        # Artifacts bundled alongside the weights.
        self.src_dict = checkpoint["src_dict"]
        self.trains_score = checkpoint["trains_score"]
        args = checkpoint["settings"]
        self.args = args

        # Rebuild the network from its saved settings, then restore its weights.
        net = BiLSTM_Cut(args)
        net.load_state_dict(checkpoint['model'])

        # Move the network and a fresh Softmax projection to the chosen device.
        # NOTE(review): nn.Softmax() without `dim` uses PyTorch's deprecated
        # implicit dimension choice; confirm the intended dim before pinning it.
        device_method = "cuda" if self.cuda else "cpu"
        net = getattr(net, device_method)()
        net.prob_projection = getattr(nn.Softmax(), device_method)()

        # eval() switches off training-only behaviour (e.g. dropout) and returns the module.
        self.model = net.eval()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号