def _slice2(_x, start, end):
    '''Return the ``start:end`` slice of ``_x`` taken along its last axis.

    All leading axes are kept whole; only the final axis is restricted.
    Any rank >= 1 is accepted.

    Args:
        _x (T.tensor): object exposing an ``ndim`` attribute and supporting
            indexing with slices (e.g. a tensor or an array).
        start (int): first index of the slice on the last axis.
        end (int): stop index (exclusive) of the slice on the last axis.

    Returns:
        T.tensor: whatever ``_x`` returns for the slice, i.e. the
        equivalent of ``_x[..., start:end]``.

    Raises:
        ValueError: if ``_x.ndim`` is smaller than 1, since there is no
            last axis to slice.
    '''
    ndim = _x.ndim
    if ndim < 1:
        raise ValueError('Number of dims (%d) not supported'
                         ' (but can add easily here)' % ndim)
    if ndim == 1:
        # Plain slice (not a 1-tuple) so 1-D inputs are indexed exactly
        # the way ``_x[start:end]`` would index them.
        return _x[start:end]
    # One full slice per leading axis, then the requested range on the
    # last axis: equivalent to writing ``_x[:, ..., :, start:end]`` by hand.
    index = (slice(None),) * (ndim - 1) + (slice(start, end),)
    return _x[index]
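# "Comment list" -- appears to be web-page navigation text, not code.
# "Article table of contents" -- appears to be web-page navigation text, not code.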