wavenet.py source file

python

Project: wavenet  Author: musyoku
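# Assumed imports: the snippet relies on Chainer's Variable and functions module.
from chainer import Variable
import chainer.functions as F


def cross_entropy(self, raw_network_output, target_signal_data):
    """Softmax cross entropy between the network output and integer target samples.

    raw_network_output: (batchsize, channels, 1, time_step)
    target_signal_data: (batchsize, time_step) array, not a Variable
    """
    if isinstance(target_signal_data, Variable):
        raise TypeError("target_signal_data cannot be Variable")

    raw_network_output = self.to_variable(raw_network_output)
    target_width = target_signal_data.shape[1]
    batchsize = raw_network_output.data.shape[0]

    # The time axis of the output must match the target length.
    if raw_network_output.data.shape[3] != target_width:
        raise ValueError("raw_network_output.width != target.width")

    # (batchsize * time_step,) <- (batchsize, time_step)
    target_signal_data = target_signal_data.reshape((-1,))
    target_signal = self.to_variable(target_signal_data)

    # (batchsize * time_step, channels) <- (batchsize, channels, 1, time_step)
    raw_network_output = F.transpose(raw_network_output, (0, 3, 2, 1))
    raw_network_output = F.reshape(raw_network_output, (batchsize * target_width, -1))

    # One classification per time step over the channel axis.
    return F.softmax_cross_entropy(raw_network_output, target_signal)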
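
The shape handling above can be checked without Chainer. The following is a minimal NumPy sketch, not the project's code: `flatten_for_loss` and `softmax_cross_entropy_np` are hypothetical helper names, the sizes are arbitrary, and the loss is averaged over all rows, which is an assumption made here for illustration. It shows that after the transpose and reshape, row `b * time_step + t` holds the channel scores for batch item `b` at time step `t`, lined up with the flattened target.

```python
import numpy as np


def flatten_for_loss(output, target):
    """Mirror the transpose/reshape step of the method above (hypothetical helper).

    output: (batchsize, channels, 1, time_step)
    target: (batchsize, time_step) integer class ids
    """
    batchsize, channels, height, width = output.shape
    if height != 1 or target.shape != (batchsize, width):
        raise ValueError("output and target shapes do not match")
    # (batchsize, time_step, 1, channels) <- (batchsize, channels, 1, time_step)
    logits = output.transpose(0, 3, 2, 1)
    # (batchsize * time_step, channels)
    logits = logits.reshape(batchsize * width, -1)
    # (batchsize * time_step,)
    return logits, target.reshape(-1)


def softmax_cross_entropy_np(logits, labels):
    """Mean softmax cross entropy over rows (assumed reduction for this sketch)."""
    # Subtract the row maximum for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the correct class in each row.
    return -log_probs[np.arange(len(labels)), labels].mean()


rng = np.random.default_rng(0)
output = rng.standard_normal((2, 4, 1, 3))   # batchsize=2, channels=4, time_step=3
target = rng.integers(0, 4, size=(2, 3))     # one class id per time step

logits, labels = flatten_for_loss(output, target)
loss = softmax_cross_entropy_np(logits, labels)
print(logits.shape, labels.shape, float(loss))
```

Flattening batch and time into a single axis lets the per-sample classification over quantization channels be treated as one large batch of independent classification problems.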