def preprocess(self, strokes):
    """Filter, clamp, rescale and length-sort the given sketches.

    Sketches with more than ``self.max_seq_length`` rows are dropped. Every
    remaining sketch is clamped element-wise to
    ``[-self.limit, self.limit]``, converted to a ``float32`` array, and its
    first two columns are divided by ``self.scale_factor``.

    Nothing is returned; the results are stored on the instance:

    * ``self.strokes`` -- list of the kept ``float32`` arrays ordered by
      ascending length. Sketches of equal length keep their input order.
    * ``self.num_batches`` -- number of kept sketches divided by
      ``self.batch_size``, truncated to an int.

    A one-line summary with the number of kept sketches is printed.

    Args:
        strokes: Sequence of sketches. Each sketch must support ``len()``
            and convert to a 2-D numeric array (the first two columns are
            presumably the x/y offsets -- confirm against the caller).
    """
    kept = []
    for sketch in strokes:
        if len(sketch) > self.max_seq_length:
            continue
        # Clamp to [-limit, limit]; this removes large gaps from the data.
        clamped = np.maximum(np.minimum(sketch, self.limit), -self.limit)
        clamped = np.array(clamped, dtype=np.float32)
        # Only the first two columns are rescaled; the rest stay as-is.
        clamped[:, 0:2] /= self.scale_factor
        kept.append(clamped)

    lengths = np.array([len(sketch) for sketch in kept])
    # A stable sort keeps equal-length sketches in input order, so the
    # result is deterministic instead of depending on the sort algorithm.
    order = np.argsort(lengths, kind="stable")
    self.strokes = [kept[i] for i in order]

    print("total images <= max_seq_len is %d" % len(kept))
    self.num_batches = int(len(kept) / self.batch_size)
评论列表
文章目录