def __init__(self, opt):
    """Build an ensemble of sub-models that share word/char embedding tables.

    Args:
        opt: configuration object. Fields read here: ``loss`` (must be
            ``'bceloss'``), ``embedding_path`` (optional pretrained-embedding
            path; falsy to skip loading), ``model_names`` / ``model_paths``
            (paired sub-model class names and checkpoint paths; a ``None``
            path means "no checkpoint"), and ``num_classes``.
            ``opt.state_dict()`` is used to clone the config per sub-model.

    Raises:
        ValueError: if ``opt.loss`` is not ``'bceloss'``.
    """
    super(MultiModelAll, self).__init__()
    # Validate before building or loading any sub-model so a bad config fails
    # fast. This was an ``assert`` at the end, which ``python -O`` strips.
    if opt.loss != 'bceloss':
        raise ValueError(
            "MultiModelAll only supports loss='bceloss', got {!r}".format(opt.loss))
    self.model_name = 'MultiModelAll'
    self.opt = opt

    # Shared embedding tables (vocab sizes and dim 256 are hard-coded).
    self.word_embedding = nn.Embedding(411720, 256)
    self.char_embedding = nn.Embedding(11973, 256)
    if opt.embedding_path:
        # A single path is configured; the word/char variant is derived by
        # swapping the substring (str.replace swaps every occurrence).
        # Word vectors are loaded first, then char vectors.
        pretrained = (
            (self.word_embedding, opt.embedding_path.replace('char', 'word')),
            (self.char_embedding, opt.embedding_path.replace('word', 'char')),
        )
        for embedding, path in pretrained:
            # Assumes an .npz archive with a 'vector' entry. np.load returns an
            # NpzFile that holds the file open; ``with`` closes it once the
            # array has been read and copied into the embedding weights.
            with np.load(path) as archive:
                embedding.weight.data.copy_(t.from_numpy(archive['vector']))

    sub_models = []
    # NOTE: zip() stops at the shorter of model_names / model_paths.
    for name, path in zip(opt.model_names, opt.model_paths):
        # Each sub-model gets its own copy of the config with embedding_path
        # cleared (presumably so it skips loading pretrained vectors itself;
        # it is handed one of the shared tables below instead).
        sub_config = Config().parse(opt.state_dict(), print_=False)
        sub_config.embedding_path = None
        model = getattr(models, name)(sub_config)
        # Restore the sub-model's own checkpoint, if one was given.
        if path is not None:
            model.load(path)
        # Replace the sub-model's encoder with the shared char- or word-level
        # table, chosen by the sub-model's ``opt.type_``.
        model.encoder = (self.char_embedding if model.opt.type_ == 'char'
                         else self.word_embedding)
        sub_models.append(model)
    # ModuleList registers the sub-models (and their parameters) on this module.
    self.models = nn.ModuleList(sub_models)
    self.model_num = len(self.models)
    # One learnable weight per (class, sub-model), initialised to 1; presumably
    # used to combine sub-model outputs per class in forward() -- confirm there.
    self.weights = nn.Parameter(t.ones(opt.num_classes, self.model_num))
评论列表
文章目录