chord_relative.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:lstmprovisor-python 作者: Impro-Visor 项目源码 文件源码
def decode_to_probs(self, activations, relative_position, low_bound, high_bound):
        """Decode relative-encoding activations into absolute-position probabilities.

        The last axis of ``activations`` is flattened into rows of width
        ``self.RAW_ENCODING_WIDTH`` and each row is softmax-normalised. For
        every row, the relative part of the distribution is rotated by that
        row's entry of ``relative_position`` (offset by ``low_bound``) and
        tiled across the ``high_bound - low_bound`` absolute positions. Two
        leading "absolute" entries are prepended: the first two softmax
        outputs when ``self.with_artic`` is set, otherwise a pair of ones.

        Args:
            activations: symbolic tensor; the reshape below requires its size
                to be a multiple of ``self.RAW_ENCODING_WIDTH`` (presumably the
                last axis has exactly that width -- TODO confirm).
            relative_position: symbolic tensor that is flattened and scanned
                in lockstep with the softmax rows, so it must hold one shift
                per row.
            low_bound: Python int, lower end of the absolute output range.
            high_bound: Python int, upper end of the absolute output range.

        Returns:
            Symbolic tensor of shape
            ``activations.shape[:-1] + (2 + high_bound - low_bound,)``.
        """
        squashed = T.reshape(activations, (-1, self.RAW_ENCODING_WIDTH))
        probs = T.nnet.softmax(squashed)

        span = high_bound - low_bound
        # Exact integer ceiling division. The previous
        # ``int(math.ceil(span / WINDOW_SIZE))`` floored before ceil under
        # Python 2 int division, yielding too few tiles for the slice below.
        # Computed once here because it does not depend on the scanned row.
        num_tile = -(-span // self.WINDOW_SIZE)

        def _scan_fn(cprobs, cpos):
            # Split one softmax row into its absolute (articulation) part and
            # its relative part; without articulation the absolute part is a
            # constant pair of ones.
            if self.with_artic:
                abs_probs = cprobs[:2]
                rel_probs = cprobs[2:]
            else:
                rel_probs = cprobs
                abs_probs = T.ones((2,))

            # Rotate the relative distribution so it lines up with absolute
            # positions starting at low_bound.
            # NOTE(review): the modulus is hard-coded to 12 while tiling uses
            # self.WINDOW_SIZE; these only agree if WINDOW_SIZE == 12 -- confirm.
            aligned = T.roll(rel_probs, (cpos - low_bound) % 12)

            # Repeat the rotated window until it covers the whole absolute
            # range, then trim to exactly `span` entries.
            tiled = T.tile(aligned, (num_tile,))[:span]

            return T.concatenate([abs_probs, tiled], 0)

        from_scan, _ = theano.map(
            fn=_scan_fn,
            sequences=[probs, T.flatten(relative_position)],
        )
        # Restore the original leading axes; the last axis becomes the
        # 2 absolute entries plus the `span` absolute positions.
        newshape = T.concatenate([activations.shape[:-1], [2 + span]], 0)
        fixed = T.reshape(from_scan, newshape, ndim=activations.ndim)
        return fixed
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号