def prepare_message(self, target_features, source_features, select_mat, gate_module):
    """Aggregate gated source-to-target messages, one row per target feature.

    Every positive entry ``select_mat[t, s]`` selects a (target ``t``,
    source ``s``) pair. All selected pairs are pushed through ``gate_module``
    in one batched call, and the outputs that belong to the same target are
    averaged. A target with no selected source gets an all-zero row instead.

    Args:
        target_features: Tensor indexed along dim 0 by target id.
        source_features: Tensor indexed along dim 0 by source id.
        select_mat: Array indexed as ``[target, source]``; entries > 0 mark
            the pairs to transfer. Presumably a NumPy array (it is fed to
            ``np.where``) -- confirm against the caller.
        gate_module: Callable invoked as ``gate_module(target_f, source_f)``
            on the gathered per-pair features; its result is indexed along
            dim 0 by pair.

    Returns:
        Tensor built by stacking one message per target along dim 0.
    """
    # np.where returns parallel arrays: pair k links target_rows[k] to
    # source_cols[k].
    pairs = np.where(select_mat > 0)
    target_rows = pairs[0]
    source_cols = pairs[1]

    # Index tensors follow the device of the tensor they index, so the method
    # works on CPU and on any GPU rather than only on the current CUDA device.
    target_indices = torch.from_numpy(target_rows).long().to(target_features.device)
    source_indices = torch.from_numpy(source_cols).long().to(source_features.device)

    target_f = torch.index_select(target_features, 0, target_indices)
    source_f = torch.index_select(source_features, 0, source_indices)
    transferred_features = gate_module(target_f, source_f)

    feature_data = []
    for f_id in range(target_features.size(0)):
        # Positions (within the pair list) of every pair aimed at this target.
        # An empty result means the select_mat row has no positive entry.
        pair_positions = np.where(target_rows == f_id)[0]
        if pair_positions.size > 0:
            indices = torch.from_numpy(pair_positions).long().to(
                transferred_features.device
            )
            features = (
                torch.index_select(transferred_features, 0, indices).mean(0).view(-1)
            )
        else:
            # Placeholder for a target that receives nothing. requires_grad is
            # kept from the original implementation so the stacked output
            # still takes part in autograd.
            features = torch.zeros(
                target_features.size()[1:],
                dtype=transferred_features.dtype,
                device=transferred_features.device,
                requires_grad=True,
            )
        feature_data.append(features)

    return torch.stack(feature_data, 0)
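# Comment list  (page-navigation text left over from the source web page; not code)
# Table of contents  (page-navigation text left over from the source web page; not code)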