def get_padded_seq_lengths(padded):
    """Return the number of non-NaN entries (seq_len) per sequence.

    NaN is treated as padding. Entries are counted along axis 1 (the time
    axis). For inputs with more than two dimensions only the first element of
    every trailing (feature) dimension is inspected, i.e. ``padded[:, :, 0]``
    for 3-d input.

    :param padded: array/tensor of shape (n_seqs, n_timesteps) or
        (n_seqs, n_timesteps, n_features, ...) using NaN as padding.
    :return: integer array of shape (n_seqs,) with the count of non-NaN
        entries per sequence.
    :raises ValueError: if ``padded`` has fewer than two dimensions.
    """
    n_dims = len(padded.shape)
    if n_dims < 2:
        # There is no (n_seqs, n_timesteps) layout to count over; fail loudly
        # instead of falling through to an unbound result.
        raise ValueError(
            "expected at least 2 dims (n_seqs, n_timesteps), "
            f"got shape {tuple(padded.shape)}"
        )
    # Reduce to (n_seqs, n_timesteps) by taking index 0 of every trailing
    # dim; for 2-d input the index tuple is just (:, :), i.e. the whole array.
    first_channel = padded[(slice(None), slice(None)) + (0,) * (n_dims - 2)]
    return np.count_nonzero(~np.isnan(first_channel), axis=1)
评论列表
文章目录