def collapse(T, W, divisive=False):
    """Contract the trailing axes of ``T`` against the 4-D array ``W``.

    Which of the two code paths runs is decided by comparing ``T.shape[-6]``
    with ``W.shape[0]``:

    * Equal ("Z only"): ``W`` is reshaped to
      ``(1, ..., 1, W.shape[0], 1, 1) + W.shape[1:]`` so that it broadcasts
      against the last six axes of ``T``; the element-wise product is then
      summed over its last three axes.  The result has shape
      ``T.shape[:-3]``.
    * Otherwise ("X only"): axis ``-6`` of ``T`` is squeezed out (NumPy
      raises ``ValueError`` if its length is not 1), the last three axes of
      ``T`` are contracted with axes 1..3 of ``W``, and the new
      ``W.shape[0]`` axis is moved to position 1.

    Args:
        T: Array with at least six dimensions (``T.shape[-6]`` is read).
        W: Array indexed as ``W[i, :, :, :]``; axes 1..3 must match the last
            three axes of ``T``.
        divisive: When true, every ``W[i]`` is first divided by the sum of
            its squared entries.

    Returns:
        The collapsed array, as described for each path above.
    """
    if divisive:
        # Per-slice normaliser: sum of squares over everything but axis 0.
        energy = np.sum(np.square(W.reshape(W.shape[0], -1)), 1)
        W = W / energy[:, None, None, None]

    if T.shape[-6] == W.shape[0]:  # Z ONLY (after 2nd-stage expansion)
        # Pad W with singleton axes so it lines up with T's last six axes.
        lead = (1,) * (T.ndim - 6)
        W = np.reshape(W, lead + (W.shape[0], 1, 1) + W.shape[1:])
        # NOTE(review): `ne` is presumably numexpr, whose evaluate() resolves
        # the names in the expression from this frame -- T and W must stay
        # bound under exactly these names.  Confirm against the file imports.
        T = ne.evaluate('T*W')
        # Flatten the last three axes into one, then reduce over it.
        T = np.reshape(T, T.shape[:-3] + (np.prod(T.shape[-3:]),))
        return np.sum(T, -1)

    # X ONLY (conv, before 2nd-stage expansion)
    T = np.squeeze(T, -6)
    T = np.tensordot(T, W, ([-3, -2, -1], [1, 2, 3]))
    # tensordot appends the W.shape[0] axis last; bring it to position 1.
    # The result always has >= 3 dims here, so moveaxis matches the legacy
    # np.rollaxis(T, -1, 1) exactly.
    return np.moveaxis(T, -1, 1)
评论列表
文章目录