load_data_sets.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目: MuGo 作者: brilee 项目源码 文件源码
def read(filename):
    """Load a DataSet from a gzip-compressed, bit-packed chunk file.

    File layout, as consumed here:
      1. a header of ``CHUNK_HEADER_SIZE`` bytes, unpacked with
         ``CHUNK_HEADER_FORMAT`` into
         (data_size, board_size, input_planes, is_test);
      2. the bit-packed position features;
      3. the bit-packed next-move targets;
      4. nothing else.

    Args:
        filename: path to the gzip-compressed chunk file.

    Returns:
        ``DataSet(pos_features, next_moves, [], is_test=is_test)``, where:
          - ``pos_features`` has shape
            ``(data_size, board_size, board_size, input_planes)``;
          - ``next_moves`` has shape ``(data_size, board_size * board_size)``;
          - both arrays are uint8 arrays of 0/1 bits.

    Raises:
        ValueError: if either packed section is shorter than the header
            declares (truncated file), or if bytes remain after the
            next-move section.
        struct.error: if the header itself is too short to unpack.
    """
    with gzip.open(filename, "rb") as f:
        header_bytes = f.read(CHUNK_HEADER_SIZE)
        data_size, board_size, input_planes, is_test = struct.unpack(
            CHUNK_HEADER_FORMAT, header_bytes)

        # Number of bits (not bytes) in each flattened array.
        position_dims = data_size * board_size * board_size * input_planes
        next_move_dims = data_size * board_size * board_size

        # numpy's bit packing pads each array to a whole byte, so round the
        # bit count up to the next multiple of 8 to get the stored byte count.
        position_nbytes = (position_dims + 7) // 8
        next_move_nbytes = (next_move_dims + 7) // 8
        packed_position_bytes = f.read(position_nbytes)
        packed_next_move_bytes = f.read(next_move_nbytes)

        # Validate explicitly instead of with `assert`, which is removed when
        # Python runs with -O; a short read means the file is truncated.
        if (len(packed_position_bytes) != position_nbytes
                or len(packed_next_move_bytes) != next_move_nbytes):
            raise ValueError(
                "{}: truncated chunk file (expected {} + {} packed bytes, "
                "got {} + {})".format(
                    filename, position_nbytes, next_move_nbytes,
                    len(packed_position_bytes), len(packed_next_move_bytes)))
        # The file should have been consumed exactly; any leftover byte means
        # the header and payload disagree.
        if f.read(1):
            raise ValueError(
                "{}: unexpected trailing data after packed arrays".format(
                    filename))

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    # Slicing drops the padding bits added when each array was packed.
    flat_position = np.unpackbits(
        np.frombuffer(packed_position_bytes, dtype=np.uint8))[:position_dims]
    flat_nextmoves = np.unpackbits(
        np.frombuffer(packed_next_move_bytes, dtype=np.uint8))[:next_move_dims]

    pos_features = flat_position.reshape(
        data_size, board_size, board_size, input_planes)
    next_moves = flat_nextmoves.reshape(data_size, board_size * board_size)

    return DataSet(pos_features, next_moves, [], is_test=is_test)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号