def set_proposal_params(self, tensor_of_proposal_means_stds_coeffs):
    """Unpack a packed tensor into proposal means, stds and coefficients.

    The input is split along dimension 0 into three equal, contiguous
    chunks, which are assigned in order to ``self.proposal_means``,
    ``self.proposal_stds`` and ``self.proposal_coeffs``. ``torch.split``
    returns views, so the three attributes share storage with the input
    tensor (no copy is made).

    Args:
        tensor_of_proposal_means_stds_coeffs: tensor whose first dimension
            has length ``3 * n_components``. The layout is presumably
            ``[means, stds, coeffs]`` for an ``n_components``-component
            proposal; confirm against the caller.

    Raises:
        ValueError: if the length of the first dimension is not a multiple
            of 3.
    """
    total = tensor_of_proposal_means_stds_coeffs.size(0)
    # Integer arithmetic instead of int(total / 3): no float round-trip, and
    # the remainder tells us whether the packing is actually well-formed.
    n_components, remainder = divmod(total, 3)
    if remainder:
        raise ValueError(
            "expected first dimension to be a multiple of 3 "
            f"(means, stds, coeffs), got {total}"
        )
    # Explicit section sizes always produce exactly three chunks, including
    # the empty case (n_components == 0), where split_size=0 would not.
    (
        self.proposal_means,
        self.proposal_stds,
        self.proposal_coeffs,
    ) = torch.split(tensor_of_proposal_means_stds_coeffs, [n_components] * 3)
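# 评论列表 (comment list) — scraped web-page navigation label, not code
# 文章目录 (table of contents) — scraped web-page navigation label, not code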