MultiModelAll.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:PyTorchText 作者: chenyuntc 项目源码 文件源码
def __init__(self, opt):
    """Build an ensemble of sub-models that share two embedding layers.

    One word-level and one char-level ``nn.Embedding`` are created here and
    installed as the ``encoder`` of every sub-model, chosen by the
    sub-model's ``opt.type_``. A ``(num_classes, model_num)`` parameter
    initialised to ones is created as ``self.weights``.

    Args:
        opt: configuration object. This constructor reads ``loss``,
            ``embedding_path``, ``model_names``, ``model_paths`` and
            ``num_classes`` from it and calls ``opt.state_dict()``.

    Raises:
        ValueError: if ``opt.loss`` is not ``'bceloss'``.
    """
    super(MultiModelAll, self).__init__()
    self.model_name = 'MultiModelAll'
    self.opt = opt

    # Validate the configuration before any sub-model is built or loaded.
    # An explicit raise (rather than ``assert``) keeps the check active
    # when Python runs with -O.
    if opt.loss != 'bceloss':
        raise ValueError(
            "MultiModelAll requires opt.loss == 'bceloss', got %r" % (opt.loss,)
        )

    # NOTE(review): the sizes below are hard-coded; presumably they are the
    # word/char vocabulary sizes and the embedding dimension, and must match
    # the shape of the arrays stored under 'vector' -- confirm.
    self.word_embedding = nn.Embedding(411720, 256)
    self.char_embedding = nn.Embedding(11973, 256)
    if opt.embedding_path:
        # The same path is used for both files by swapping the substrings
        # 'char' <-> 'word' in it.
        word_path = opt.embedding_path.replace('char', 'word')
        char_path = opt.embedding_path.replace('word', 'char')
        self.word_embedding.weight.data.copy_(
            t.from_numpy(np.load(word_path)['vector']))
        self.char_embedding.weight.data.copy_(
            t.from_numpy(np.load(char_path)['vector']))

    sub_models = []
    for _name, _path in zip(opt.model_names, opt.model_paths):
        # Each sub-model gets its own config copy with embedding_path
        # cleared (presumably so it does not load embeddings itself --
        # verify against the sub-model constructors).
        tmp_config = Config().parse(opt.state_dict(), print_=False)
        tmp_config.embedding_path = None
        _model = getattr(models, _name)(tmp_config)
        # Restore saved state only when a path is given for this model.
        if _path is not None:
            _model.load(_path)
        # Replace the sub-model's encoder with the shared embedding. This
        # happens after load(), so the shared layer always wins.
        _model.encoder = (self.char_embedding
                          if _model.opt.type_ == 'char'
                          else self.word_embedding)
        sub_models.append(_model)
    self.models = nn.ModuleList(sub_models)

    self.model_num = len(self.models)
    self.weights = nn.Parameter(t.ones(opt.num_classes, self.model_num))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号