utils_class.py source code

python

Project: sketch_rnn_classification  Author: payalbajaj
def _get_batch_from_indices(self, indices):
    """Given a list of indices, return the potentially augmented batch."""
    x_batch = []
    seq_len = []
    x_labels = []
    for i in indices:
      data = self.random_scale(self.strokes[i])
      data_copy = np.copy(data)
      if self.augment_stroke_prob > 0:
        data_copy = augment_strokes(data_copy, self.augment_stroke_prob)
      x_batch.append(data_copy)
      length = len(data_copy)
      seq_len.append(length)
      x_labels.append(self.labels[i])
    seq_len = np.array(seq_len, dtype=int)
    # Returns four things: the stroke-3 batch, its labels, the padded
    # stroke-5 batch, and the array of sequence lengths.
    return x_batch, x_labels, self.pad_batch(x_batch, self.max_seq_length), seq_len
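
The method above returns both a stroke-3 batch and a padded stroke-5 batch, but `pad_batch` itself is not shown in this excerpt. The following is a minimal sketch of what such a conversion could look like, assuming stroke-3 rows are `(dx, dy, pen_lifted)` and stroke-5 rows are `(dx, dy, pen_down, pen_up, end_of_sketch)`. The function name `pad_to_stroke5` is hypothetical, and the project's real `pad_batch` may differ (for example, by prepending a start token).

```python
import numpy as np


def pad_to_stroke5(batch, max_len):
    """Hypothetical helper: pad a list of stroke-3 arrays to one stroke-5 array.

    Assumes each item has shape (n, 3) with columns (dx, dy, pen_lifted).
    """
    result = np.zeros((len(batch), max_len, 5), dtype=float)
    for i, strokes in enumerate(batch):
        n = len(strokes)
        if n > max_len:
            raise ValueError("sequence longer than max_len")
        result[i, :n, 0:2] = strokes[:, 0:2]   # copy the dx, dy offsets
        result[i, :n, 3] = strokes[:, 2]       # pen lifted after this point
        result[i, :n, 2] = 1 - strokes[:, 2]   # pen stays on the paper
        result[i, n:, 4] = 1                   # padding rows mark end of sketch
    return result


# One sketch with two points, padded to length 4.
example = [np.array([[1.0, 2.0, 0.0], [3.0, -1.0, 1.0]])]
padded = pad_to_stroke5(example, max_len=4)
print(padded.shape)
```

In this layout the last three columns of every row form a one-hot pen state, so a model can tell real points from padding without needing `seq_len`. The lengths are still returned separately for use in masking.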