def encode(self,
           data: mx.sym.Symbol,
           data_length: Optional[mx.sym.Symbol],
           seq_len: int) -> Tuple[mx.sym.Symbol, mx.sym.Symbol, int]:
    """
    Swaps the first two axes of the input symbol and passes the length
    information through untouched.

    :param data: Input data symbol. Presumably batch-major on entry, given that
                 the result is tagged with the time-major layout (confirm against callers).
    :param data_length: Vector with sequence lengths; returned as given.
    :param seq_len: Maximum sequence length; returned as given.
    :return: Tuple of (data with axes 0 and 1 swapped, data_length, seq_len).
    """
    # Build the swapaxes symbol inside the attribute scope so that it is
    # created with __layout__ set to C.TIME_MAJOR.
    with mx.AttrScope(__layout__=C.TIME_MAJOR):
        swapped = mx.sym.swapaxes(data=data, dim1=0, dim2=1)
    return swapped, data_length, seq_len
评论列表
文章目录