def __init__(self, traces, sort=True):
    """Summarise a batch of traces and group them into sub-batches.

    Args:
        traces: Sized iterable of trace objects. Each trace is expected to
            expose ``length``, ``observes_tensor`` (an object with a
            ``size(dim)`` method -- presumably a torch tensor, TODO confirm)
            and ``addresses_suffixed()`` returning a hashable value.
        sort: If True, ``self.batch`` becomes a new list of the traces in
            decreasing ``length`` order; otherwise it is ``traces`` itself.

    Attributes set:
        batch: the traces (sorted or not, see ``sort``).
        length: number of traces.
        traces_max_length: largest ``trace.length`` (0 for an empty batch).
        observes_max_length: largest ``trace.observes_tensor.size(0)``
            (0 for an empty batch).
        sub_batches: lists of traces sharing an equal
            ``addresses_suffixed()`` value, in first-seen order.
        traces_lengths: ``length`` of each trace, in ``batch`` order.

    Raises:
        TypeError: if a trace's ``length`` is None, since it is compared
            against an int below.
    """
    self.batch = traces
    self.length = len(traces)
    self.traces_max_length = 0
    self.observes_max_length = 0
    # Group by the address sequence itself rather than by its hash():
    # distinct sequences whose hashes collide must not share a sub-batch.
    groups = {}
    for trace in traces:
        if not trace.length:
            # Covers both an unset (None) and a genuinely empty (0) length,
            # so the warning matches its own message.
            util.logger.log('Batch: Received a trace of length zero.')
        if trace.length > self.traces_max_length:
            self.traces_max_length = trace.length
        num_observes = trace.observes_tensor.size(0)
        if num_observes > self.observes_max_length:
            self.observes_max_length = num_observes
        groups.setdefault(trace.addresses_suffixed(), []).append(trace)
    # dicts keep insertion order, so sub-batches appear in the order in
    # which their first trace was encountered.
    self.sub_batches = list(groups.values())
    if sort:
        # Sort the batch in decreasing trace length.
        self.batch = sorted(self.batch, reverse=True, key=lambda t: t.length)
    self.traces_lengths = [t.length for t in self.batch]
评论列表
文章目录