def _get_batch_from_indices(self, indices):
    """Given a list of indices, return the potentially augmented batch.

    For each index ``i``, ``self.strokes[i]`` is passed through
    ``self.random_scale``, copied with ``np.copy``, and — only when
    ``self.augment_stroke_prob > 0`` — passed through ``augment_strokes``.

    Args:
        indices: Iterable of positions into ``self.strokes`` and
            ``self.labels``; output order follows this iteration order.

    Returns:
        A 4-tuple ``(x_batch, x_labels, padded_batch, seq_len)``:
            x_batch: list of the per-example arrays after scaling/augmentation
                (stroke-3 format).
            x_labels: list of ``self.labels[i]``, aligned with ``x_batch``.
            padded_batch: ``self.pad_batch(x_batch, self.max_seq_length)``
                (stroke-5 format).
            seq_len: ``np.ndarray`` of ``int`` holding ``len()`` of each
                element of ``x_batch``.
    """
    x_batch = []
    x_labels = []
    seq_len = []
    for i in indices:
        data = self.random_scale(self.strokes[i])
        # Work on a private copy so the augmentation step below cannot touch
        # whatever array random_scale handed back.
        data_copy = np.copy(data)
        if self.augment_stroke_prob > 0:
            data_copy = augment_strokes(data_copy, self.augment_stroke_prob)
        x_batch.append(data_copy)
        # Length is measured after any augmentation has been applied.
        seq_len.append(len(data_copy))
        x_labels.append(self.labels[i])
    seq_len = np.array(seq_len, dtype=int)
    # We return four things: stroke-3 batch, labels, stroke-5 (padded) batch,
    # and the array of sequence lengths.
    return x_batch, x_labels, self.pad_batch(x_batch, self.max_seq_length), seq_len
utils_class.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录