layers.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:tsnet 作者: coxlab 项目源码 文件源码
def collapse(T, W, divisive=False):

    if divisive: W = W / np.sum(np.square(W.reshape(W.shape[0], -1)), 1)[:,None,None,None]

    if T.shape[-6] == W.shape[0]: # Z ONLY (after 2nd-stage expansion)

        W = np.reshape (W, (1,)*(T.ndim-6) + (W.shape[0],1,1) + W.shape[1:])
        T = ne.evaluate('T*W')
        T = np.reshape (T, T.shape[:-3] + (np.prod(T.shape[-3:]),))
        T = np.sum(T, -1)

    else: # X ONLY (conv, before 2nd-stage expansion)

        T = np.squeeze  (T, -6)
        T = np.tensordot(T, W, ([-3,-2,-1], [1,2,3]))
        T = np.rollaxis (T, -1, 1)

    return T
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号