def setUp(self):
    """Build the fixtures shared by the AttentiveRead tests.

    Sets up:
      * ``self._memory``: float32 values 0, 1, 2, ... reshaped to
        ``[batch_size, memory_size, memory_word_size]``.
      * ``self._query``: float32 values 0, 1, 2, ... reshaped to
        ``[batch_size, query_word_size]``.
      * ``self._memory_mask``: a ``[batch_size, memory_size]`` bool tensor.
      * ``self._attention_mod``: an ``snt.AttentiveRead`` built around a
        ``ConstantZero`` attention-logit module.
    """
    super(AttentiveReadTest, self).setUp()
    self._batch_size = 3
    self._memory_size = 4
    self._memory_word_size = 1
    self._query_word_size = 2

    # Derive tensor shapes from the sizes above rather than repeating the
    # literals, so changing a size keeps the fixtures consistent.
    memory_shape = [
        self._batch_size, self._memory_size, self._memory_word_size]
    memory_elems = (
        self._batch_size * self._memory_size * self._memory_word_size)
    self._memory = tf.reshape(
        tf.cast(tf.range(0, memory_elems), dtype=tf.float32),
        shape=memory_shape)

    query_shape = [self._batch_size, self._query_word_size]
    query_elems = self._batch_size * self._query_word_size
    self._query = tf.reshape(
        tf.cast(tf.range(0, query_elems), dtype=tf.float32),
        shape=query_shape)

    # Row b has its first (memory_size - b) entries set to True and the rest
    # False. NOTE(review): this literal is hard-coded for batch_size=3 and
    # memory_size=4; update it if those sizes change. True presumably marks
    # a usable memory slot for snt.AttentiveRead -- confirm against the
    # module's documentation.
    self._memory_mask = tf.convert_to_tensor(
        [
            [True, True, True, True],
            [True, True, True, False],
            [True, True, False, False],
        ],
        dtype=tf.bool)

    # ConstantZero is defined outside this view; presumably it produces
    # constant (zero) attention logits so the expected attention weights
    # are easy to compute by hand -- verify against its definition.
    self._attention_logit_mod = ConstantZero()
    self._attention_mod = snt.AttentiveRead(self._attention_logit_mod)
评论列表
文章目录