def trim(tensor, width):
    """Trim the tensor on the -1 axis.

    Trim a given tensor of shape `[..., in_width]` to a smaller tensor
    of shape `[..., width]`, along the -1 axis. If the `width` argument
    is greater than or equal to the actual width of the tensor, the input
    tensor is returned unchanged.

    Arguments:
      tensor: a `tf.Tensor` of shape `[..., in_width]` (previously
        documented as 3D; this function itself only inspects the last
        axis -- whether `_trim` accepts other ranks is not verified here).
      width: an `int` representing the target size of the last
        dimension of the output tensor.

    Returns:
      a tensor of shape `[..., min(in_width, width)]`. The static size of
      the last dimension is set only when it can be determined at graph
      construction time (i.e. when `in_width` is statically known).

    Example:
    ```python
    # t is a tensor like:
    # [[[1, 1, 1],
    #   [2, 2, 2],
    #   [3, 3, 3]],
    #  [[7, 7, 7],
    #   [8, 8, 8],
    #   [9, 9, 9]]]
    q = trim(t, 2)
    # q is a tensor like:
    # [[[1, 1],
    #   [2, 2],
    #   [3, 3]],
    #  [[7, 7],
    #   [8, 8],
    #   [9, 9]]]
    ```
    """
    result = tf.cond(
        tf.less_equal(tf.shape(tensor)[-1], width),
        lambda: tensor,
        lambda: _trim(tensor, width))

    static_shape = tensor.get_shape()
    # `as_list()` raises on unknown rank; in that case there is no static
    # shape information worth propagating, so leave the result as-is.
    if static_shape.ndims is not None:
        dims = static_shape.as_list()
        in_width = dims[-1]
        # The output width is min(in_width, width). Claiming `width`
        # unconditionally (as before) is wrong whenever the input is
        # narrower: it either raises in `set_shape` or records a false
        # static shape. When `in_width` is unknown, so is the output width.
        out_width = None if in_width is None else min(in_width, width)
        result.set_shape(dims[:-1] + [out_width])
    return result