def call(self, inputs, mask=None):
    """Run the wrapped layer, reorder and flatten its output, then apply the mask.

    Steps performed:
      1. ``self.layer.call(inputs)`` produces the raw outputs.
      2. The outputs' axes are permuted with ``self.permute_pattern``
         followed by the index of the last input axis, so that axis ends
         up last.
      3. The result is reshaped to rank 3 as
         ``(-1, output_shape[1], output_shape[2])``, where ``output_shape``
         comes from ``self.compute_output_shape(input_shape)``.
      4. The mask from ``self.compute_mask(inputs, mask)`` (if any) is cast
         to ``K.floatx()`` and multiplied elementwise into the outputs, so
         positions where the mask is 0 become 0.

    Args:
        inputs: Input tensor passed to the wrapped layer.
        mask: Optional incoming mask, forwarded to ``self.compute_mask``.

    Returns:
        The rank-3 reshaped outputs multiplied by the mask, or the reshaped
        outputs unchanged when ``self.compute_mask`` returns ``None``.
    """
    input_shape = K.int_shape(inputs)
    outputs = self.layer.call(inputs)

    # Keep the last input axis in the final position. list() lets
    # permute_pattern be given as either a list or a tuple.
    # NOTE(review): assumes the wrapped layer's output has the same rank as
    # `inputs` - TODO confirm.
    pattern = list(self.permute_pattern) + [len(input_shape) - 1]
    outputs = K.permute_dimensions(outputs, pattern)

    output_shape = self.compute_output_shape(input_shape)
    outputs = K.reshape(outputs, (-1, output_shape[1], output_shape[2]))

    mask_tensor = self.compute_mask(inputs, mask)
    if mask_tensor is None:
        # Nothing to apply; K.cast(None, ...) would otherwise fail.
        return outputs

    # Add a trailing axis and let broadcasting stretch it across the last
    # output dimension. Unlike repeat_elements(..., output_shape[2], 2), this
    # does not materialize a copy and also works when that dimension is
    # unknown (None).
    # NOTE(review): assumes mask_tensor is rank 2 with shape
    # (N, output_shape[1]), N matching the reshaped outputs' first
    # dimension - TODO confirm against compute_mask.
    mask_tensor = K.expand_dims(K.cast(mask_tensor, K.floatx()))
    return outputs * mask_tensor
评论列表
文章目录