def compute_mask(self, input, mask):
    """Compute the output mask for this layer from the incoming mask.

    The result depends on ``self.return_mode``:

    * ``mask`` is None, or mode ``"last_output"``: no mask (``None``).
    * mode ``"all_outputs"``: the input mask is passed through unchanged.
    * any other mode (the ``"output_and_memory"`` case): one unmasked slot is
      prepended for the output, followed by the input mask for the memory
      slots, and the result is cast to ``'uint8'``.

    Args:
        input: The layer input. It is not read here; it is kept to match the
            Keras ``compute_mask(input, mask)`` signature.
        mask: The incoming mask, or None. It is sliced as ``mask[:, :1]``, so it
            must be at least 2-D; presumably (batch_size, input_length).

    Returns:
        ``None``, ``mask`` itself, or a ``'uint8'`` tensor of shape
        (batch_size, input_length + 1).
    """
    if mask is None or self.return_mode == "last_output":
        return None
    if self.return_mode == "all_outputs":
        return mask  # (batch_size, input_length)
    # Return mode is output_and_memory: the first slot corresponds to the output,
    # the remaining slots to memory, one per input. Memory slots inherit the
    # input mask and the output slot is never masked. Keras treats nonzero/True
    # mask entries as "keep", so the output slot must be ones. The previous
    # zeros_like masked the output, contrary to the stated intent.
    output_mask = K.ones_like(mask[:, :1])
    # (batch_size, input_length + 1)
    return K.cast(K.concatenate([output_mask, mask]), 'uint8')