ops.py — source file

Language: python
Views 34 · Favorites 0 · Likes 0 · Comments 0

Project: DaNet-Tensorflow · Author: khaotik · project source · file source
def pit_mse_loss(s_x, s_y, pit_axis=1, perm_size=None, name='pit_loss'):
    '''
    Permutation invariant MSE loss, batched version.

    For each sample, a cross-loss matrix is built whose entry (i, j) is the
    mean squared (complex: squared-magnitude) difference between entry i of
    ``s_x`` and entry j of ``s_y`` along ``pit_axis``. For every permutation,
    the matching entries of that matrix are summed. The smallest sum is
    selected per sample, and the result is averaged over the batch.

    Args:
        s_x: tensor whose axis 0 is the batch axis; real or complex
        s_y: tensor broadcast-compatible with s_x (same rank assumed)
        pit_axis: which axis permutations occur; negative values are
            allowed, but it must not resolve to the batch axis 0
        perm_size: size of permutation. By default it is inferred from the
            static size of ``s_x`` along ``pit_axis``; if that size is
            unknown, ``hparams.MAX_N_SIGNAL`` is used
        name: string, variable scope name

    Returns:
        s_loss, v_perms, s_loss_sets_idx

        s_loss: scalar loss
        v_perms: constant int matrix of permutations, one permutation of
            ``range(perm_size)`` per row
        s_loss_sets_idx: int64 vector with one entry per sample, giving the
            row of v_perms selected for that sample
    '''
    x_shp = s_x.get_shape().as_list()
    ndim = len(x_shp)

    assert -ndim <= pit_axis < ndim
    pit_axis %= ndim
    # axis 0 is the batch axis; permuting it is not meaningful
    assert pit_axis != 0

    if perm_size is None:
        perm_size = x_shp[pit_axis]
    if perm_size is None:
        # static size along pit_axis is unknown; fall back to the global max
        perm_size = hparams.MAX_N_SIGNAL

    # After the expand_dims calls below, the broadcast difference has rank
    # ndim+1, with the two permutation axes at pit_axis and pit_axis+1.
    # Reduce every other axis except the batch axis 0.
    reduce_axes = [
        i for i in range(1, ndim+1) if i not in (pit_axis, pit_axis+1)]
    with tf.variable_scope(name):
        v_perms = tf.constant(
            list(itertools.permutations(range(perm_size))),
            dtype=hparams.INTX)
        # s_perms_onehot[p, i, j] == 1 iff v_perms[p, i] == j
        s_perms_onehot = tf.one_hot(
            v_perms, perm_size, dtype=hparams.FLOATX)

        # s_x varies along pit_axis (index i), s_y along pit_axis+1 (index j)
        s_x = tf.expand_dims(s_x, pit_axis+1)
        s_y = tf.expand_dims(s_y, pit_axis)
        if s_x.dtype.is_complex and s_y.dtype.is_complex:
            s_diff = s_x - s_y
            s_cross_loss = tf.reduce_mean(
                tf.square(tf.real(s_diff)) + tf.square(tf.imag(s_diff)),
                reduce_axes)
        else:
            s_cross_loss = tf.reduce_mean(
                tf.squared_difference(s_x, s_y), reduce_axes)
        # s_loss_sets[b, p] = sum_i s_cross_loss[b, i, v_perms[p, i]]
        s_loss_sets = tf.einsum(
            'bij,pij->bp', s_cross_loss, s_perms_onehot)
        s_loss_sets_idx = tf.argmin(s_loss_sets, axis=1)
        # Use the runtime batch size instead of hparams.BATCH_SIZE, so that
        # batches of any size (e.g. a short final batch) are indexed correctly.
        s_batch_idx = tf.range(
            tf.shape(s_loss_sets, out_type=tf.int64)[0], dtype=tf.int64)
        s_loss = tf.gather_nd(
            s_loss_sets,
            tf.stack([s_batch_idx, s_loss_sets_idx], axis=1))
        s_loss = tf.reduce_mean(s_loss)
    return s_loss, v_perms, s_loss_sets_idx
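Comments

Table of contents


Questions


Interview notes


Articles

WeChat
Official account

Scan the QR code to follow the official account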