_dataflow.py source code

python

Project: keras_experiments    Author: avolkov1
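import logging

import numpy as np

logger = logging.getLogger(__name__)


def _aggregate_batch(data_holder, use_list=False):
    """Stack a list of datapoints into a batch, one component at a time.

    data_holder is a list of datapoints; each datapoint is a list/tuple with
    the same number of components. Returns one entry per component.
    """
    size = len(data_holder[0])
    result = []
    for k in range(size):
        if use_list:
            # Keep the k-th component of every datapoint as a plain list.
            result.append([x[k] for x in data_holder])
            continue
        # The first datapoint's k-th component decides the batch dtype.
        dt = data_holder[0][k]
        if type(dt) in (int, bool):
            tp = 'int32'
        elif type(dt) is float:
            tp = 'float32'
        else:
            try:
                tp = dt.dtype
            except AttributeError:
                raise TypeError("Unsupported type to batch: {}"
                                .format(type(dt))) from None
        try:
            result.append(
                np.asarray([x[k] for x in data_holder], dtype=tp))
        except Exception:
            # KeyboardInterrupt is not an Exception subclass, so it already
            # propagates without a dedicated handler.
            logger.exception("Cannot batch data. Perhaps they are of "
                             "inconsistent shape?")
            try:
                # Open an IPython shell for debugging, if IPython is installed.
                import IPython as IP
                IP.embed(config=IP.terminal.ipapp.load_default_config())
            except ImportError:
                pass
            # Re-raise so the caller never receives a batch missing a component.
            raise
    return result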
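To make the batching rule concrete, here is a minimal sketch of the same per-component logic in compact form. The name `aggregate_components` and the sample data are hypothetical. Like `_aggregate_batch`, it assumes that every datapoint has the same number of components and that the first datapoint decides each component's dtype. It also shows what happens when the components have inconsistent shapes.

```python
import numpy as np


def aggregate_components(datapoints):
    """Hypothetical compact sketch: turn N datapoints of K components each
    into K batched arrays, using the same dtype rule as _aggregate_batch."""
    batched = []
    # zip(*datapoints) groups the k-th component of every datapoint together.
    for column in zip(*datapoints):
        first = column[0]
        if type(first) in (int, bool):
            dtype = 'int32'
        elif type(first) is float:
            dtype = 'float32'
        else:
            dtype = first.dtype  # numpy arrays and numpy scalars carry a dtype
        batched.append(np.asarray(column, dtype=dtype))
    return batched


# Three hypothetical datapoints, each an (image, label) pair.
points = [[np.zeros((2, 2), dtype='float32'), 0],
          [np.ones((2, 2), dtype='float32'), 1],
          [np.full((2, 2), 2.0, dtype='float32'), 2]]
images, labels = aggregate_components(points)
print(images.shape, images.dtype)     # (3, 2, 2) float32
print(labels.tolist(), labels.dtype)  # [0, 1, 2] int32

# Components with inconsistent shapes cannot be stacked into one array.
ragged = [[np.zeros(2, dtype='float32')], [np.zeros(3, dtype='float32')]]
try:
    aggregate_components(ragged)
except ValueError as err:
    print("cannot batch:", err)
```

Because the dtype is taken from the first datapoint only, a label column whose first entry is a Python `int` is cast to `int32` even if a later entry is a float.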