import itertools

import tensorflow as tf

# `hparams` is the project's configuration module (name/path assumed here);
# it must provide BATCH_SIZE, MAX_N_SIGNAL, INTX, and FLOATX.
import hparams


def pit_mse_loss(s_x, s_y, pit_axis=1, perm_size=None, name='pit_loss'):
    '''
    Permutation-invariant MSE loss, batched version.

    Args:
        s_x: tensor of estimates
        s_y: tensor of references, same shape as s_x
        pit_axis: axis along which permutations occur
        perm_size: size of the permutation; inferred from the tensor
            shape by default
        name: variable scope name

    Returns:
        s_loss, v_perms, s_loss_sets_idx

        s_loss: scalar loss
        v_perms: constant int matrix of all permutations
        s_loss_sets_idx: int vector, index into v_perms of the
            permutation selected for each batch item
    '''
    x_shp = s_x.get_shape().as_list()
    ndim = len(x_shp)
    batch_size = x_shp[0]
    if batch_size is None:
        batch_size = hparams.BATCH_SIZE
    assert -ndim <= pit_axis < ndim
    pit_axis %= ndim
    assert pit_axis != 0, 'permutation must not occur along the batch axis'
    if perm_size is None:
        # infer the permutation size from the static shape, falling back
        # to the hparams setting when the shape is unknown
        perm_size = x_shp[pit_axis] or hparams.MAX_N_SIGNAL
    # axes to average over, in the coordinates of the expanded tensors
    # built below (hence the shift by one): everything except the batch
    # axis and the two permutation axes
    reduce_axes = [
        i for i in range(1, ndim + 1) if i not in [pit_axis, pit_axis + 1]]
    with tf.variable_scope(name):
        # all perm_size! permutations, as a [n_perms, perm_size] int matrix
        v_perms = tf.constant(
            list(itertools.permutations(range(perm_size))),
            dtype=hparams.INTX)
        # one-hot encode each permutation into a stack of
        # [perm_size, perm_size] permutation matrices
        s_perms_onehot = tf.one_hot(
            v_perms, perm_size, dtype=hparams.FLOATX)
        # insert broadcast axes so that subtraction forms every
        # (estimate, reference) pair along the permutation axis
        s_x = tf.expand_dims(s_x, pit_axis + 1)
        s_y = tf.expand_dims(s_y, pit_axis)
        if s_x.dtype.is_complex and s_y.dtype.is_complex:
            # for complex signals, take the MSE over the real and
            # imaginary parts separately
            s_diff = s_x - s_y
            s_cross_loss = tf.reduce_mean(
                tf.square(tf.real(s_diff)) + tf.square(tf.imag(s_diff)),
                reduce_axes)
        else:
            s_cross_loss = tf.reduce_mean(
                tf.squared_difference(s_x, s_y), reduce_axes)
        # s_cross_loss[b, i, j] is the MSE between estimate i and
        # reference j; contracting with each permutation matrix gives
        # s_loss_sets[b, p], the total loss of batch item b under
        # permutation p
        s_loss_sets = tf.einsum(
            'bij,pij->bp', s_cross_loss, s_perms_onehot)
        # pick the lowest-loss permutation per batch item
        s_loss_sets_idx = tf.argmin(s_loss_sets, axis=1)
        s_loss = tf.gather_nd(
            s_loss_sets,
            tf.stack([
                tf.range(batch_size, dtype=tf.int64),
                s_loss_sets_idx], axis=1))
        # average the per-item minimum losses over the batch
        s_loss = tf.reduce_mean(s_loss)
    return s_loss, v_perms, s_loss_sets_idx
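

# A minimal usage sketch, assuming this function sits next to an `hparams`
# module as imported above; the tensor shapes and the float32 cast are
# illustrative assumptions, not taken from the original project.
if __name__ == '__main__':
    import numpy as np

    # [batch, n_signal, time]: source estimates and references
    s_est = tf.placeholder(
        hparams.FLOATX,
        [hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, 16000])
    s_ref = tf.placeholder(
        hparams.FLOATX,
        [hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, 16000])
    s_loss, v_perms, s_idx = pit_mse_loss(s_est, s_ref, pit_axis=1)

    with tf.Session() as sess:
        feed = {
            s_est: np.random.randn(
                hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, 16000
            ).astype(np.float32),
            s_ref: np.random.randn(
                hparams.BATCH_SIZE, hparams.MAX_N_SIGNAL, 16000
            ).astype(np.float32),
        }
        loss_val, idx_val = sess.run([s_loss, s_idx], feed_dict=feed)
        # idx_val[b] is the row of v_perms with the lowest MSE for item b
        print(loss_val, idx_val)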