def __call__(self, s_embed, s_src_pwr, s_mix_pwr, s_embed_flat=None):
    """Estimate one attractor per source from the embeddings assigned to it.

    Each embedding bin is assigned to the source with the largest value of
    ``s_src_pwr`` along axis 1. The attractor for a source is the sum of its
    assigned embeddings divided by (assigned count + 1).

    Args:
        s_embed: embedding tensor. It is reshaped to
            [BATCH_SIZE, -1, EMBED_SIZE] when ``s_embed_flat`` is not given.
        s_src_pwr: per-source power. Its argmax over axis 1 gives the source
            index of each bin. Presumably axis 1 is the source axis; TODO
            confirm against callers.
        s_mix_pwr: not used by this method (presumably accepted so every
            estimator shares one call signature; verify).
        s_embed_flat: optional pre-flattened embeddings, shaped like the
            reshape above: [BATCH_SIZE, N, EMBED_SIZE].

    Returns:
        Float tensor of shape [BATCH_SIZE, MAX_N_SIGNAL, EMBED_SIZE].
    """
    if s_embed_flat is None:
        s_embed_flat = tf.reshape(
            s_embed,
            [hparams.BATCH_SIZE, -1, hparams.EMBED_SIZE])

    def _segment_sum(args):
        # Sum the rows of `data` per segment id, for one batch element.
        data, segment_ids = args
        return tf.unsorted_segment_sum(
            data, segment_ids, hparams.MAX_N_SIGNAL)

    with tf.variable_scope(self.name):
        # int64[B, N]: index of the dominant source for every bin.
        s_src_assignment = tf.argmax(s_src_pwr, axis=1)
        s_indices = tf.reshape(
            s_src_assignment,
            [hparams.BATCH_SIZE, -1])
        # float[B, C, E]: sum of the embeddings assigned to each source.
        s_attractors = tf.map_fn(
            _segment_sum, (s_embed_flat, s_indices), hparams.FLOATX)
        # float[B, C]: number of bins assigned to each source. Counting
        # ones over [B, N] gives the same values as segment-summing an
        # all-ones [B, N, E] tensor (every E channel held the same count),
        # without materialising and reducing that larger tensor.
        s_counts = tf.map_fn(
            _segment_sum,
            (tf.ones_like(s_indices, dtype=hparams.FLOATX), s_indices),
            hparams.FLOATX)
        # The +1 keeps sources with no assigned bins at zero (0 / 1)
        # instead of dividing by zero. It also shrinks every mean slightly
        # toward zero (sum / (n + 1)); kept as-is to preserve numerics.
        s_attractors /= (tf.expand_dims(s_counts, -1) + 1.)

    if hparams.DEBUG:
        self.debug_fetches = {}
    # float[B, C, E]
    return s_attractors
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Table of contents")