def splittensor(axis=1, ratio_split=1, id_split=0, **kwargs):
def f(X):
div = X.shape[axis] // ratio_split
if axis == 0:
output = X[id_split * div:(id_split + 1) * div, :, :, :]
elif axis == 1:
output = X[:, id_split * div:(id_split + 1) * div, :, :]
elif axis == 2:
output = X[:, :, id_split * div:(id_split + 1) * div, :]
elif axis == 3:
output = X[:, :, :, id_split * div:(id_split + 1) * div]
else:
raise ValueError('This axis is not possible')
return output
def g(input_shape):
output_shape = list(input_shape)
output_shape[axis] = output_shape[axis] // ratio_split
return tuple(output_shape)
return Lambda(f, output_shape=lambda input_shape: g(input_shape), **kwargs)
评论列表
文章目录