def test_seq_data_mask(self):
mask_cache_key = str(id(self.model.input)) + '_' + str(id(None))
mask_tensor = self.model._output_mask_cache[mask_cache_key]
mask = mask_tensor.eval(
session=K.get_session(),
feed_dict={self.model.input: self.seq_data}
)
self.assertTrue(
np.all(
mask[:, :self.seq_data_max_length]
)
)
self.assertFalse(
np.any(
mask[:, self.seq_data_max_length:]
)
)
评论列表
文章目录