mps_base.py source code

python

Project: MSDN  Author: yikang-li
import numpy as np
import torch


def prepare_message(self, target_features, source_features, select_mat, gate_module):
    # Method of the enclosing class in mps_base.py (class definition not shown in this excerpt).
    # select_mat: NumPy array of shape (num_targets, num_sources); an entry > 0 means
    # source j sends a message to target i.
    device = target_features.device
    feature_data = []

    # All (target, source) pairs with a positive entry in select_mat, in row-major order.
    transfer_list = np.where(select_mat > 0)
    source_indices = torch.from_numpy(transfer_list[1]).long().to(device)
    target_indices = torch.from_numpy(transfer_list[0]).long().to(device)
    source_f = torch.index_select(source_features, 0, source_indices)
    target_f = torch.index_select(target_features, 0, target_indices)
    # Gate each source feature by its target feature: one output row per pair.
    transferred_features = gate_module(target_f, source_f)

    for f_id in range(target_features.size(0)):
        if (select_mat[f_id, :] > 0).any():
            # Average every gated message addressed to target f_id.
            feature_indices = np.where(transfer_list[0] == f_id)[0]
            indices = torch.from_numpy(feature_indices).long().to(device)
            features = torch.index_select(transferred_features, 0, indices).mean(0).view(-1)
            feature_data.append(features)
        else:
            # No incoming messages: pad with a constant zero vector (no gradient needed).
            temp = target_features.new_zeros(target_features.size()[1:])
            feature_data.append(temp)
    return torch.stack(feature_data, 0)
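The per-target loop in prepare_message averages the gated messages sent to each target row and falls back to zeros when that row of select_mat has no positive entry. The same aggregation can be expressed without a Python loop as a scatter-mean. The sketch below illustrates the idea in NumPy. `aggregate_messages` is a hypothetical helper, not part of the MSDN code, and it assumes the messages have already been passed through the gate (gate_module itself is not reproduced here). In PyTorch, the analogous building blocks would be `Tensor.index_add_` and `torch.bincount`.

```python
import numpy as np


def aggregate_messages(select_mat, messages):
    """Average per-pair messages into one row per target (a scatter-mean).

    Hypothetical helper for illustration only.
    select_mat: (num_targets, num_sources) array; an entry > 0 marks a
                (target, source) pair, as in prepare_message.
    messages:   (num_pairs, feature_dim) array with one row per pair, in the
                same order as np.where(select_mat > 0). Assumed already gated.
    Returns a (num_targets, feature_dim) array; targets with no incoming
    pair get an all-zero row.
    """
    messages = np.asarray(messages, dtype=np.float64)
    target_idx, _ = np.where(select_mat > 0)
    num_targets = select_mat.shape[0]

    # Unbuffered scatter-add: repeated target indices accumulate instead of overwriting.
    sums = np.zeros((num_targets, messages.shape[1]))
    np.add.at(sums, target_idx, messages)

    # Number of messages received by each target (0 for isolated targets).
    counts = np.bincount(target_idx, minlength=num_targets)

    # Divide only the rows that received messages; the rest stay zero.
    out = np.zeros_like(sums)
    has_msg = counts > 0
    out[has_msg] = sums[has_msg] / counts[has_msg][:, None]
    return out


select_mat = np.array([[1, 0, 1],
                       [0, 0, 0],
                       [0, 1, 0]])
# One message per pair; np.where yields the pairs (0, 0), (0, 2), (2, 1).
messages = np.array([[1.0, 2.0],
                     [3.0, 4.0],
                     [5.0, 6.0]])
# Rows: target 0 -> mean of [1, 2] and [3, 4]; target 1 -> zeros; target 2 -> [5, 6].
print(aggregate_messages(select_mat, messages))
```

`np.add.at` is used instead of `sums[target_idx] += messages` because fancy-indexed `+=` is buffered: when a target index repeats, only one of the additions would survive.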