def __call__(self, s_embed, s_src_pwr, s_mix_pwr, s_embed_flat=None):
    """Estimate attractors as power-weighted means of assigned embeddings.

    Each flattened element is hard-assigned to the index of the largest
    entry of ``s_src_pwr`` along axis 1. For every batch item and each of
    ``hparams.MAX_N_SIGNAL`` indices, the embeddings assigned to that index
    are summed, weighted by ``s_mix_pwr``. The sum is then divided by the
    summed weights plus ``hparams.EPS``.

    Args:
        s_embed: Embedding tensor. It is only read when ``s_embed_flat``
            is None; in that case it is reshaped to
            ``[BATCH_SIZE, -1, EMBED_SIZE]``.
        s_src_pwr: Per-source power. Argmax is taken over axis 1, so
            axis 1 presumably indexes sources (TODO confirm).
        s_mix_pwr: Mixture power. It is reshaped to ``[BATCH_SIZE, -1, 1]``
            and used as per-element weights. Its flattened element order is
            assumed to match both ``s_embed_flat`` and the argmax of
            ``s_src_pwr`` (TODO confirm against the caller).
        s_embed_flat: Optional pre-flattened embeddings of shape
            ``[BATCH_SIZE, -1, EMBED_SIZE]``. Derived from ``s_embed``
            when omitted.

    Returns:
        float[B, C, E] attractor tensor, where C is ``hparams.MAX_N_SIGNAL``.
        An index with no assigned elements has a zero weighted sum, so it
        yields a zero vector (assuming ``hparams.EPS`` is non-zero).
    """
    if s_embed_flat is None:
        embed_shape = [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE]
        s_embed_flat = tf.reshape(s_embed, embed_shape)

    with tf.variable_scope(self.name):
        # One scalar weight per flattened element, kept as a trailing
        # size-1 axis so it broadcasts against the embedding axis.
        s_weight = tf.reshape(s_mix_pwr, [hparams.BATCH_SIZE, -1, 1])
        # Hard assignment: dominant index along axis 1, per element.
        s_labels = tf.reshape(
            tf.argmax(s_src_pwr, axis=1), [hparams.BATCH_SIZE, -1])

        def _sum_per_signal(elems):
            # Runs once per batch item under tf.map_fn. Rows of ``values``
            # that share a label are summed into one of MAX_N_SIGNAL buckets.
            values, labels = elems
            return tf.unsorted_segment_sum(
                values, labels, hparams.MAX_N_SIGNAL)

        s_weighted_sum = tf.map_fn(
            _sum_per_signal, (s_embed_flat * s_weight, s_labels),
            dtype=hparams.FLOATX)
        s_weight_sum = tf.map_fn(
            _sum_per_signal, (s_weight, s_labels),
            dtype=hparams.FLOATX)
        # EPS keeps buckets with no assigned elements from dividing by zero.
        s_attractors = s_weighted_sum / (s_weight_sum + hparams.EPS)

    if hparams.DEBUG:
        self.debug_fetches = dict()
    # float[B, C, E]
    return s_attractors
评论列表
文章目录