def forward(self, tensors: List[torch.Tensor],    # pylint: disable=arguments-differ
            mask: torch.Tensor = None) -> torch.Tensor:
    """
    Compute a weighted average of the ``tensors``. The input tensors can be any shape
    with at least two dimensions, but must all be the same shape.

    When ``do_layer_norm=True``, the ``mask`` is a required input. If the ``tensors`` are
    dimensioned ``(dim_0, ..., dim_{n-1}, dim_n)``, then the ``mask`` is dimensioned
    ``(dim_0, ..., dim_{n-1})``, as in the typical case with ``tensors`` of shape
    ``(batch_size, timesteps, dim)`` and ``mask`` of shape ``(batch_size, timesteps)``.

    When ``do_layer_norm=False`` the ``mask`` is ignored.
    """
    if len(tensors) != self.mixture_size:
        raise ConfigurationError("{} tensors were passed, but the module was initialized to "
                                 "mix {} tensors.".format(len(tensors), self.mixture_size))
    def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
        tensor_masked = tensor * broadcast_mask
        mean = torch.sum(tensor_masked) / num_elements_not_masked
        variance = torch.sum(((tensor_masked - mean) * broadcast_mask)**2) / num_elements_not_masked
        return (tensor - mean) / torch.sqrt(variance + 1E-12)
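
    # The softmax turns the learned scalar parameters into mixing weights that sum
    # to 1; ``split`` then yields one (1,)-shaped weight per input tensor.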
    normed_weights = torch.nn.functional.softmax(torch.cat([parameter for parameter
                                                             in self.scalar_parameters]), dim=0)
    # Pass the split size positionally: newer PyTorch renamed the keyword
    # ``split_size`` to ``split_size_or_sections``.
    normed_weights = torch.split(normed_weights, 1)

    if not self.do_layer_norm:
        pieces = []
        for weight, tensor in zip(normed_weights, tensors):
            pieces.append(weight * tensor)
        return self.gamma * sum(pieces)

    else:
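        # Layer-norm statistics are shared across the whole unmasked batch: the mean
        # and variance are scalars computed over every non-padded position and every
        # feature dimension, not per token.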
        mask_float = mask.float()
        broadcast_mask = mask_float.unsqueeze(-1)
        input_dim = tensors[0].size(-1)
        num_elements_not_masked = torch.sum(mask_float) * input_dim

        pieces = []
        for weight, tensor in zip(normed_weights, tensors):
            pieces.append(weight * _do_layer_norm(tensor,
                                                  broadcast_mask, num_elements_not_masked))
        return self.gamma * sum(pieces)
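

# --------------------------------------------------------------------------
# Usage sketch (not part of the AllenNLP source): a rough illustration of how
# this ``forward`` is typically called, assuming the surrounding module was
# built as ``ScalarMix(mixture_size=3, do_layer_norm=True)``. The import path
# and all tensor sizes below are assumptions made for the example.
# --------------------------------------------------------------------------
import torch
from allennlp.modules.scalar_mix import ScalarMix

batch_size, timesteps, dim = 2, 5, 4
# Three same-shaped layers of representations, e.g. the three ELMo layers.
layers = [torch.randn(batch_size, timesteps, dim) for _ in range(3)]
# True marks real tokens, False marks padding; required because do_layer_norm=True.
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True, True]])

scalar_mix = ScalarMix(mixture_size=3, do_layer_norm=True)
mixed = scalar_mix(layers, mask)   # weighted, layer-normed average of the layers
print(mixed.shape)                 # torch.Size([2, 5, 4])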