def split_in_chunks(minibatch, num_splits, flatten_keys=('labels',)):
    """Return the splits per device.

    Each value of ``minibatch`` is split along its first axis into
    ``num_splits`` chunks with ``np.array_split``. Chunks may have unequal
    sizes when the length is not divisible by ``num_splits``.

    Args:
        minibatch: mapping from key to an array-like value. It is NOT
            modified.
        num_splits: number of chunks (devices) to produce.
        flatten_keys: keys whose chunks are flattened to 1-D after
            splitting. ``None`` means no key is flattened. Defaults to
            ``('labels',)``.

    Returns:
        A list of ``num_splits`` dictionaries, one per device. The i-th
        dictionary maps every key of ``minibatch`` to the i-th chunk of
        that key's value. An empty ``minibatch`` yields ``num_splits``
        empty dictionaries.
    """
    # Materialise once. Membership tests then work for any iterable,
    # including one-shot generators.
    flatten_keys = () if flatten_keys is None else tuple(flatten_keys)

    # Build the per-key chunk lists in a fresh dict. The caller's
    # minibatch is left untouched.
    chunks_per_key = {}
    for key, value in minibatch.items():
        pieces = np.array_split(value, num_splits)
        if key in flatten_keys:
            pieces = [piece.flatten() for piece in pieces]
        chunks_per_key[key] = pieces

    # Transpose {key: [chunk_0, ..., chunk_n]} into
    # [{key: chunk_0}, ..., {key: chunk_n}], one dict per device.
    return [
        {key: pieces[i] for key, pieces in chunks_per_key.items()}
        for i in range(num_splits)
    ]
评论列表
文章目录