def BatchedDot(x, y, last_axis=False):
    """Batched dot product of two symbolic tensors via ``T.batched_dot``.

    If ``last_axis`` is falsy, this is exactly ``T.batched_dot(x, y)``, so the
    batch dimension is axis 0 of the inputs and of the result.

    If ``last_axis`` is truthy, the batch dimension is instead the LAST axis
    of ``x`` and ``y``. The function moves each input's last axis to the
    front, applies ``T.batched_dot``, and then moves the result's axis 0 back
    to the end, so the output also carries the batch on its last axis.

    Args:
        x: Symbolic tensor. Must have 2-4 dimensions when ``last_axis`` is
            truthy.
        y: Symbolic tensor. Must have 2-4 dimensions when ``last_axis`` is
            truthy.
        last_axis: Selects the batch-last layout when truthy. Any falsy value
            (``False``, ``0``, ``None``) selects the default axis-0 layout.

    Returns:
        The symbolic result of the batched dot product.

    Raises:
        ValueError: If ``last_axis`` is truthy and ``x`` or ``y`` does not
            have 2, 3 or 4 dimensions.
    """
    # Truthiness test instead of ``== False``. Previously ``last_axis=None``
    # matched neither branch and the function returned None.
    if not last_axis:
        return T.batched_dot(x, y)

    def _batch_to_front(tensor, label):
        # Validate rank, then move the last axis to position 0. The pattern
        # is (n-1, 0, 1, ..., n-2), which equals the hand-written patterns
        # used for n = 2, 3, 4.
        if not 2 <= tensor.ndim <= 4:
            raise ValueError('BatchedDot inputs must have between 2-4 dimensions, but ' + label + ' has ' + str(tensor.ndim) + ' dimensions')
        return tensor.dimshuffle([tensor.ndim - 1] + list(range(tensor.ndim - 1)))

    # Same order as before: x is validated/shuffled first, then y.
    shuffled_x = _batch_to_front(x, 'x')
    shuffled_y = _batch_to_front(y, 'y')
    dot = T.batched_dot(shuffled_x, shuffled_y)

    # Move the batch axis (axis 0 of the result) back to the end for ANY
    # result rank. The previous version only handled ranks 2-4 and returned
    # None otherwise, e.g. for two matrices (1-D result) or two 4-D tensors
    # (5-D result). For a 1-D result the pattern [0] is the identity.
    return dot.dimshuffle(list(range(1, dot.ndim)) + [0])
评论列表
文章目录