def _aggregate_batch(data_holder, use_list=False):
    """Group a list of datapoints into per-component batches.

    Component ``k`` of the result collects ``dp[k]`` for every datapoint
    ``dp`` in ``data_holder``.

    Args:
        data_holder: non-empty list of indexable datapoints. The number of
            components is taken from ``data_holder[0]``; every datapoint is
            presumably expected to have at least that many components
            (a shorter one raises IndexError).
        use_list: if True, each component is returned as a plain Python
            list and no type inspection or numpy conversion is done.

    Returns:
        A list with one entry per component. With ``use_list=False``, each
        entry is a numpy array whose dtype is chosen from the FIRST
        datapoint's value: ``int``/``bool`` -> ``'int32'``,
        ``float`` -> ``'float32'``, otherwise that value's ``.dtype``.

    Raises:
        TypeError: if a component value is not int/bool/float and has no
            ``.dtype`` attribute.
        Exception: whatever ``np.asarray`` raises (e.g. ValueError when the
            components have inconsistent shapes). It is logged and then
            re-raised rather than swallowed, so callers never receive a
            result with missing or shifted components.
    """
    size = len(data_holder[0])
    result = []
    for k in range(size):
        column = [dp[k] for dp in data_holder]
        if use_list:
            result.append(column)
            continue

        first = column[0]
        # Exact type checks are intentional: numpy scalars such as
        # np.float64 are not matched here and keep their own dtype below.
        if type(first) in (int, bool):
            dtype = 'int32'
        elif type(first) == float:
            dtype = 'float32'
        else:
            try:
                dtype = first.dtype
            except AttributeError:
                raise TypeError("Unsupported type to batch: {}"
                                .format(type(first)))

        try:
            result.append(np.asarray(column, dtype=dtype))
        except Exception:
            # Previously this opened an interactive IPython shell and then
            # continued without appending, which silently dropped component
            # k and misaligned every later component. Log and propagate.
            logger.exception("Cannot batch data. Perhaps they are of "
                             "inconsistent shape?")
            raise
    return result
评论列表
文章目录