def splittensor(axis=1, ratio_split=1, id_split=0, **kwargs):
    """Build a ``Lambda`` layer that keeps one chunk of a tensor split along *axis*.

    The size of dimension *axis* is floor-divided by *ratio_split* to get a
    chunk length ``div``. The layer keeps the elements
    ``[id_split * div, (id_split + 1) * div)`` along that axis and leaves every
    other axis untouched. If the dimension is not an exact multiple of
    *ratio_split*, the trailing remainder belongs to no chunk.

    Args:
        axis: Non-negative index of the axis to split. The input tensor must
            have at least ``axis + 1`` dimensions.
        ratio_split: Number of equal chunks to cut the axis into; must be >= 1.
        id_split: Zero-based index of the chunk to keep; must satisfy
            ``0 <= id_split < ratio_split``.
        **kwargs: Forwarded unchanged to ``Lambda``.

    Returns:
        A ``Lambda`` layer wrapping the slicing function together with a
        matching ``output_shape`` function.

    Raises:
        ValueError: If *axis* is negative, *ratio_split* is < 1, or *id_split*
            is outside ``[0, ratio_split)``.
    """
    if axis < 0:
        raise ValueError("This axis is not possible")
    if ratio_split < 1:
        raise ValueError(
            "ratio_split must be >= 1, got {}".format(ratio_split))
    # An out-of-range id_split would slice past the last full chunk, so the
    # tensor produced would disagree with the shape reported below.
    if not 0 <= id_split < ratio_split:
        raise ValueError(
            "id_split must be in [0, {}), got {}".format(ratio_split, id_split))

    def _take_chunk(X):
        # Chunk length comes from K.shape (the backend shape of X), not from a
        # Python-side static shape.
        div = K.shape(X)[axis] // ratio_split
        # Full slices for every axis before `axis`; axes after it are left
        # implicit, which indexing treats as full slices.
        index = (slice(None),) * axis + (
            slice(id_split * div, (id_split + 1) * div),)
        return X[index]

    def _chunk_shape(input_shape):
        output_shape = list(input_shape)
        dim = output_shape[axis]
        # Unknown dimensions (e.g. the batch axis) stay unknown instead of
        # failing with TypeError on ``None // ratio_split``.
        output_shape[axis] = None if dim is None else dim // ratio_split
        return tuple(output_shape)

    return Lambda(_take_chunk, output_shape=_chunk_shape, **kwargs)
评论列表
文章目录