def visual(self, input_ts, target_ts, mask_ts, output_ts=None, batch_index=0):
    """Print a human-readable dump of one batch element to stdout.

    Each tensor is sliced at ``batch_index`` along dim 1. Every selected
    column along dim 2 is rendered with ``self._readable`` and printed on
    its own line under a section header ("Input:", "Target:", "Mask:",
    and "Output:" when an output is given).

    Args:
        input_ts: tensor of shape [(num_wordsx2+2) x batch_size x (len_word+2)].
        target_ts: tensor of shape [(num_wordsx2+2) x batch_size x (len_word)].
        mask_ts: tensor of shape [(num_wordsx2+2) x batch_size x (len_word)].
        output_ts: optional tensor of shape
            [(num_wordsx2+2) x batch_size x (len_word)]. When given, it is
            multiplied element-wise by ``mask_ts`` and rounded before
            rendering.
        batch_index: index along dim 1 of the batch element to display.
            Defaults to 0, the element the original code always showed.

    Returns:
        None; everything is written to stdout.
    """

    def render(title, ts, columns):
        # Header line followed by one rendered line per selected column.
        rows = [self._readable(ts[:, batch_index, col]) for col in columns]
        return title + ':\n' + '\n'.join(rows)

    sections = [
        render('Input', input_ts, range(input_ts.size(2))),
        render('Target', target_ts, range(target_ts.size(2))),
        # Only the first mask column is shown, as in the original code.
        # NOTE(review): presumably the mask is identical across dim 2 -- confirm.
        render('Mask', mask_ts, [0]),
    ]

    if output_ts is not None:
        # Apply the mask, then round, before rendering the model output.
        masked_output = torch.round(output_ts * mask_ts)
        sections.append(
            render('Output', masked_output, range(masked_output.size(2)))
        )

    for section in sections:
        print(section)
评论列表
文章目录