def __call__(self, s_embed, s_src_pwr=None, s_mix_pwr=None, s_embed_flat=None):
    """Estimate attractors by picking the best set of learned anchors.

    Creates ``hparams.NUM_ANCHOR`` learnable anchor vectors and forms every
    combination of ``hparams.MAX_N_SIGNAL`` of them. For each combination it:

    * soft-assigns every embedding to the anchors of that combination;
    * builds one attractor per anchor as the assignment-weighted mean of the
      embeddings.

    For every batch item it then keeps the attractor set whose attractors
    are least similar to one another. The equation numbers in the comments
    presumably refer to the anchored deep attractor network (ADANet)
    formulation; confirm against the paper.

    Args:
        s_embed: embedding tensor, indexed as ``btfe`` in the einsums below.
            Presumably (batch, time, freq, EMBED_SIZE); TODO confirm.
        s_src_pwr: not used by this method; kept for interface compatibility.
        s_mix_pwr: not used by this method; kept for interface compatibility.
        s_embed_flat: not used by this method; kept for interface
            compatibility.

    Returns:
        For each batch item, the selected attractor set. It is gathered along
        the subset axis of a ``bpce``-indexed tensor, so the shape is
        (batch, c, e).

    Side effects:
        When ``hparams.DEBUG`` is true, stores the intermediate tensors in
        ``self.debug_fetches``.
    """
    with tf.variable_scope(self.name):
        # Learnable anchors, shape [NUM_ANCHOR, EMBED_SIZE].
        v_anchors = tf.get_variable(
            'anchors', [hparams.NUM_ANCHOR, hparams.EMBED_SIZE],
            initializer=tf.random_normal_initializer(stddev=1.))
        # All combinations of MAX_N_SIGNAL anchors. The result is indexed as
        # 'pce' below; presumably [n_subsets, MAX_N_SIGNAL, EMBED_SIZE].
        # Confirm against ops.combinations.
        s_anchor_sets = ops.combinations(v_anchors, hparams.MAX_N_SIGNAL)

        # equation (6): score every embedding against every anchor of every
        # subset. Softmax runs over the last axis (c), so for each
        # (b, p, t, f) the weights over the subset's anchors sum to 1.
        s_anchor_assignment = tf.einsum(
            'btfe,pce->bptfc', s_embed, s_anchor_sets)
        s_anchor_assignment = tf.nn.softmax(s_anchor_assignment)

        # equation (7): each attractor is the assignment-weighted mean of the
        # embeddings. The weighted sum is divided by the total weight summed
        # over t and f. Result shape: [b, p, c, e].
        s_attractor_sets = tf.einsum(
            'bptfc,btfe->bpce', s_anchor_assignment, s_embed)
        s_attractor_sets /= tf.expand_dims(
            tf.reduce_sum(s_anchor_assignment, axis=(2, 3)), -1)

        # equation (8): the in-set similarity is the largest dot product
        # between two *different* attractors of a set. The Gram matrix
        # diagonal holds each attractor's squared norm. Left in place, it
        # would dominate the max, so it is shifted to the most negative
        # representable value before reducing.
        s_pair_similarities = tf.matmul(
            s_attractor_sets,
            tf.transpose(s_attractor_sets, [0, 1, 3, 2]))
        n_attractor = tf.shape(s_pair_similarities)[-1]
        s_self_mask = tf.eye(n_attractor, dtype=s_pair_similarities.dtype)
        s_in_set_similarities = tf.reduce_max(
            s_pair_similarities
            - s_self_mask * s_pair_similarities.dtype.max,
            axis=(-1, -2))

        # equation (9): for each batch item, choose the subset whose
        # attractors are least similar to each other.
        s_best_subset = tf.argmin(s_in_set_similarities, axis=1)
        # Take the batch size from the data instead of hparams.BATCH_SIZE,
        # so partial or differently sized batches still gather correctly.
        s_batch_index = tf.range(
            tf.shape(s_best_subset, out_type=tf.int64)[0], dtype=tf.int64)
        # [batch, 2] index pairs (batch_idx, subset_idx) for gather_nd.
        s_subset_choice = tf.stack([s_batch_index, s_best_subset], axis=1)
        s_attractors = tf.gather_nd(s_attractor_sets, s_subset_choice)
        if hparams.DEBUG:
            self.debug_fetches = dict(
                asets=s_attractor_sets,
                anchors=v_anchors,
                subset_choice=s_subset_choice)
        return s_attractors
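# Leftover web-page labels from the page this code was copied from (not code):
# 评论列表 ("comment list"), 文章目录 ("table of contents")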