def __init__(self, model_source="model", cuda=False):
    """Load a serialized BiLSTM_Cut checkpoint and prepare it for inference.

    Args:
        model_source: Object handed straight to ``torch.load`` (e.g. a file
            path); defaults to ``"model"``.
        cuda: If true, tensors are loaded with their saved placement and the
            model is moved to the GPU; otherwise every storage is remapped to
            the CPU and the model stays on the CPU.

    Sets ``self.torch`` (``torch.cuda`` or ``torch``), ``self.cuda``,
    ``self.src_dict``, ``self.trains_score``, ``self.args`` (the checkpoint's
    ``"settings"`` entry) and ``self.model`` (the network in eval mode).
    """
    self.torch = torch.cuda if cuda else torch
    self.cuda = cuda

    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # trusted checkpoints. On PyTorch >= 2.6 the default weights_only=True
    # may reject non-tensor entries such as "settings"; confirm against the
    # torch version in use.
    load_kwargs = {}
    if not self.cuda:
        # Identity map_location: storages stay on the CPU where they were
        # deserialized instead of being restored to their saved CUDA device.
        load_kwargs["map_location"] = lambda storage, loc: storage
    checkpoint = torch.load(model_source, **load_kwargs)

    self.src_dict = checkpoint["src_dict"]
    self.trains_score = checkpoint["trains_score"]
    settings = checkpoint["settings"]
    self.args = settings

    net = BiLSTM_Cut(settings)
    net.load_state_dict(checkpoint['model'])

    # NOTE(review): nn.Softmax() without an explicit dim uses PyTorch's
    # deprecated implicit-dimension choice (and warns); pass dim= once the
    # shape fed to prob_projection is confirmed.
    if self.cuda:
        net = net.cuda()
        projection = nn.Softmax().cuda()
    else:
        net = net.cpu()
        projection = nn.Softmax().cpu()
    net.prob_projection = projection

    self.model = net.eval()
# Scraped web-page residue (blog UI labels), not code:
# "评论列表" = comment list, "文章目录" = table of contents.