embedding_model.py source code

python

Project: source_separation_ml_jeju    Author: hjkwon0609
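def add_loss_op(self, voice_spec, song_spec):
        """Build the training loss: mean squared difference between the embedding
        affinity V V^T and the target affinity Y Y^T built from binary masks
        (the deep clustering objective), plus an L2 penalty on 2-D trainable weights."""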
        if not EmbeddingConfig.use_vpnn:
            # flatten batch and time into one axis: [num_batches * time_frames, num_freq_bins]
            voice_spec = tf.reshape(voice_spec, [-1, EmbeddingConfig.num_freq_bins])
            song_spec = tf.reshape(song_spec, [-1, EmbeddingConfig.num_freq_bins])

        self.voice_spec = voice_spec  # for output
        self.song_spec = song_spec

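        # binary mask: 1 where |song_spec| exceeds |voice_spec| in a TF bin, else 0
        song_spec_mask = tf.cast(tf.abs(song_spec) > tf.abs(voice_spec), tf.float32)
        # complementary mask; 1.0 - mask also works when the static batch dimension is unknown,
        # whereas tf.ones(mask.get_shape()) fails on a partially known shape
        voice_spec_mask = 1.0 - song_spec_mask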

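        V = self.embedding  # [num_batches * time_frames, num_freq_bins, embedding_dim]
        # one-hot source label per TF bin (channel 0 = song, 1 = voice):
        # [num_batches * time_frames, num_freq_bins, 2]
        Y = tf.stack([song_spec_mask, voice_spec_mask], axis=-1)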

        A_pred = tf.matmul(V, tf.transpose(V, [0, 2, 1]))    # embedding affinity V V^T: [N, F, F]
        A_target = tf.matmul(Y, tf.transpose(Y, [0, 2, 1]))  # target affinity Y Y^T: 1 where two bins share a source
        # mean squared error over every entry of the per-frame F x F affinity matrices
        error = tf.reduce_mean(tf.square(A_pred - A_target))

        # tf.summary.histogram('a_same cluster embedding distribution', A_pred * A_target)
        # tf.summary.histogram('a_different cluster embedding distribution', A_pred * (1 - A_target))

        # tf.summary.histogram('V', V)
        # tf.summary.histogram('V V^T', A_pred)

        # sum of (unsquared) Frobenius norms of all 2-D trainable weight matrices
        l2_cost = tf.reduce_sum([tf.norm(v) for v in tf.trainable_variables() if len(v.get_shape().as_list()) == 2])

        self.loss = EmbeddingConfig.l2_lambda * l2_cost + error

        # tf.summary.scalar("avg_loss", self.loss)
        # tf.summary.scalar('regularizer cost', EmbeddingConfig.l2_lambda * l2_cost)
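The loss above compares the embedding affinity matrix V Vᵀ with the target affinity Y Yᵀ built from the two binary masks. The NumPy sketch below is an illustration only and is not part of the original project. Random arrays stand in for the flattened spectrogram frames and for `self.embedding`, and the function names and sizes are hypothetical. It reproduces the mask construction and the mean-squared affinity error. It then checks the result against the equivalent low-rank form ‖VVᵀ − YYᵀ‖²_F = ‖VᵀV‖²_F − 2‖VᵀY‖²_F + ‖YᵀY‖²_F, which never builds an F × F matrix.

```python
import numpy as np


def affinity_loss(V, Y):
    """Direct form: mean of (V V^T - Y Y^T)^2 over all N * F * F entries.

    V: [N, F, K] embeddings; Y: [N, F, C] one-hot source assignments.
    """
    A_pred = V @ V.transpose(0, 2, 1)    # [N, F, F] embedding affinity
    A_target = Y @ Y.transpose(0, 2, 1)  # [N, F, F] 1 where two bins share a source
    return np.mean((A_pred - A_target) ** 2)


def affinity_loss_lowrank(V, Y):
    """Same value without building any F x F matrix, via
    ||V V^T - Y Y^T||_F^2 = ||V^T V||_F^2 - 2 ||V^T Y||_F^2 + ||Y^T Y||_F^2.
    """
    N, F, _ = V.shape
    Vt = V.transpose(0, 2, 1)
    Yt = Y.transpose(0, 2, 1)
    sq_norm = (np.sum((Vt @ V) ** 2)          # [N, K, K]
               - 2.0 * np.sum((Vt @ Y) ** 2)  # [N, K, C]
               + np.sum((Yt @ Y) ** 2))       # [N, C, C]
    return sq_norm / (N * F * F)              # same averaging as tf.reduce_mean


# Toy data standing in for flattened spectrogram frames (hypothetical sizes).
rng = np.random.default_rng(0)
N, F, K = 4, 16, 3
song = rng.standard_normal((N, F))
voice = rng.standard_normal((N, F))

# Binary masks, mirroring song_spec_mask / voice_spec_mask above.
song_mask = (np.abs(song) > np.abs(voice)).astype(np.float64)
voice_mask = 1.0 - song_mask
Y = np.stack([song_mask, voice_mask], axis=-1)  # [N, F, 2]

V = rng.standard_normal((N, F, K))  # random stand-in for self.embedding

print(affinity_loss(V, Y), affinity_loss_lowrank(V, Y))  # equal up to rounding
print(affinity_loss(Y, Y))  # 0.0: embeddings equal to the one-hot labels fit perfectly
```

The low-rank form costs roughly O(F·K²) per frame instead of O(F²·K). This matters when `num_freq_bins` is large relative to the embedding size. Both functions divide by N·F·F, matching the `tf.reduce_mean` averaging in the original.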