def decode(self, memories, keys, num_keys=None):
    """Decode ``keys`` against ``memories`` via ``self.conv_func(conj=True)``.

    The keys are normalized with ``self._normalize`` and expanded into their
    permuted copies by ``self.perm_keys`` (using ``self.perms``). The rows of
    the result are then reordered so that all rows belonging to the same key
    are contiguous, and handed to ``self.conv_func`` with ``conj=True``.

    Args:
        memories: Tensor of stored memories. Its static shape is read via
            ``get_shape()``; an unknown first dimension is replaced by
            ``self.num_models``.
        keys: Tensor of keys to decode. NOTE(review): presumably shaped
            ``(num_keys, dim)`` -- confirm against callers.
        num_keys: Number of keys. Defaults to the static first dimension of
            ``keys``; must be passed explicitly when that dimension is not
            statically known.

    Returns:
        Whatever ``self.conv_func`` returns for the reordered keys.

    Raises:
        ValueError: if ``num_keys`` is not given and cannot be inferred from
            the static shape of ``keys``.
    """
    # Local import: the file's top-level import block is not visible here.
    import logging
    logger = logging.getLogger(__name__)

    keys = self._normalize(keys)

    num_memories = memories.get_shape().as_list()
    if num_memories[0] is None:
        # Unknown static leading dim: fall back to the number of models.
        num_memories[0] = self.num_models

    if num_keys is None:
        num_keys = keys.get_shape().as_list()[0]
    if num_keys is None:
        # Without a concrete count the Python-level slicing loop below
        # (range(num_keys)) cannot be built.
        raise ValueError('decode: num_keys must be passed explicitly when the '
                         'first dimension of keys is not statically known')

    logger.debug('decode: numkeys = %s | num_memories = %s',
                 num_keys, num_memories)

    # Re-gather keys to avoid mixing between different keys.
    perms = self.perm_keys(keys, self.perms, num_keys=num_keys)

    # End bounds for strided_slice; fill in statically-unknown dims.
    perm_shape = perms.get_shape().as_list()
    if perm_shape[0] is None:
        perm_shape[0] = num_keys * self.num_models
    if perm_shape[1] is None:
        perm_shape[1] = num_memories[1]

    # Slice i takes rows i, i + num_keys, i + 2*num_keys, ... of perms, so
    # concatenating the slices for i = 0..num_keys-1 places every row of
    # key i next to each other.
    # NOTE(review): this assumes perm_keys emits rows interleaved by key
    # (row = model * num_keys + key) -- confirm against perm_keys.
    # NOTE: tf.concat(0, values) is the legacy (TF < 1.0) argument order;
    # TF >= 1.0 expects tf.concat(values, axis).
    permed_keys = tf.concat(0, [
        tf.strided_slice(perms, [i, 0], perm_shape, [num_keys, 1])
        for i in range(num_keys)
    ])

    logger.debug('memories = %s | dec_perms = %s',
                 num_memories, permed_keys.get_shape().as_list())

    return self.conv_func(memories, permed_keys,
                          num_memories[0],
                          self.num_models,
                          num_keys=num_keys * self.num_models,
                          conj=True)
评论列表
文章目录