def load(cls, f):
h = cls()
magic = f.read(8)
if len(magic) != 8 or magic != cls._MAGIC:
raise InvalidModelFormatError('invalid magic value: {0}'.format(str(magic)))
for (key, fmt, _) in cls.fields():
size = struct.calcsize(fmt)
raw = f.read(size)
if len(raw) != size:
raise InvalidModelFormatError('failed to read {0} in header (expected {1} bytes, got {2} bytes)'.format(key, size, len(raw)))
try:
value = struct.unpack(fmt, raw)[0]
        except struct.error:
raise InvalidModelFormatError('failed to parse {0} value {1} as {2}'.format(key, str(raw), fmt))
setattr(h, key, value)
return h
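# A minimal round-trip sketch of the struct-based field parsing used above;
# the '>I' format and the value 42 are illustrative, not the actual field
# table returned by cls.fields().
import struct
from io import BytesIO

fmt = '>I'                           # big-endian 32-bit unsigned integer
buf = BytesIO(struct.pack(fmt, 42))
size = struct.calcsize(fmt)          # 4 bytes for '>I'
raw = buf.read(size)
assert struct.unpack(fmt, raw)[0] == 42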
def load_object(path, build_fn, *args, **kwargs):
""" load from serialized form or build an object, saving the built
object; kwargs provided to `build_fn`.
"""
save = False
obj = None
if path is not None and os.path.isfile(path):
with open(path, 'rb') as obj_f:
            # msgpack < 1.0 API; newer msgpack removed encoding= (use raw=False instead)
            obj = msgpack.load(obj_f, use_list=False, encoding='utf-8')
else:
save = True
if obj is None:
obj = build_fn(*args, **kwargs)
if save and path is not None:
with open(path, 'wb') as obj_f:
msgpack.dump(obj, obj_f)
return obj
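# A hypothetical usage sketch for load_object: build a token->id vocabulary
# once, cache it as msgpack, and reuse the cached file on later runs.
# build_vocab, corpus, and the cache path are illustrative names.
def build_vocab(corpus):
    tokens = sorted({tok for line in corpus for tok in line.split()})
    return {tok: i for i, tok in enumerate(tokens)}

corpus = ['the quick fox', 'the lazy dog']
vocab = load_object('vocab.msgpack', build_vocab, corpus)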
def load_data(opt):
with open('SQuAD/meta.msgpack', 'rb') as f:
meta = msgpack.load(f, encoding='utf8')
embedding = torch.Tensor(meta['embedding'])
opt['pretrained_words'] = True
opt['vocab_size'] = embedding.size(0)
opt['embedding_dim'] = embedding.size(1)
opt['pos_size'] = len(meta['vocab_tag'])
opt['ner_size'] = len(meta['vocab_ent'])
    with open(args.data_file, 'rb') as f:  # `args` is parsed at module level in the original script
data = msgpack.load(f, encoding='utf8')
train = data['train']
    # sort dev by context length (x[1]) so padded batches stay compact
    data['dev'].sort(key=lambda x: len(x[1]))
dev = [x[:-1] for x in data['dev']]
dev_y = [x[-1] for x in data['dev']]
return train, dev, dev_y, embedding, opt
def read(self, stream):
"""Given a readable file descriptor object (something `load`able by
msgpack or json), read the data, and return the Python representation
of the contents. One-shot reader.
"""
return self.reader.load(stream)
def load(self, stream):
return self.decoder.decode(json.load(stream,
object_pairs_hook=OrderedDict))
def load(self, stream):
return self.decoder.decode(msgpack.load(stream,
object_pairs_hook=OrderedDict))
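# Both loaders pass object_pairs_hook=OrderedDict so mappings keep the key
# order found in the stream; a quick illustration with the JSON variant
# (the stream contents here are hypothetical):
import json
from collections import OrderedDict
from io import StringIO

doc = json.load(StringIO('{"b": 1, "a": 2}'), object_pairs_hook=OrderedDict)
assert list(doc) == ['b', 'a']       # insertion order kept, not sorted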
def load_data(opt):
with open('SQuAD/meta.msgpack', 'rb') as f:
meta = msgpack.load(f, encoding='utf8')
embedding = torch.Tensor(meta['embedding'])
opt['pretrained_words'] = True
opt['vocab_size'] = embedding.size(0)
opt['embedding_dim'] = embedding.size(1)
    if not opt['fix_embeddings']:
        # re-randomize row 1 (typically <UNK>) when embeddings are trainable
        embedding[1] = torch.normal(mean=torch.zeros(opt['embedding_dim']), std=1.)
with open(args.data_file, 'rb') as f:
data = msgpack.load(f, encoding='utf8')
train_orig = pd.read_csv('SQuAD/train.csv')
dev_orig = pd.read_csv('SQuAD/dev.csv')
train = list(zip(
data['trn_context_ids'],
data['trn_context_features'],
data['trn_context_tags'],
data['trn_context_ents'],
data['trn_question_ids'],
train_orig['answer_start_token'].tolist(),
train_orig['answer_end_token'].tolist(),
data['trn_context_text'],
data['trn_context_spans']
))
dev = list(zip(
data['dev_context_ids'],
data['dev_context_features'],
data['dev_context_tags'],
data['dev_context_ents'],
data['dev_question_ids'],
data['dev_context_text'],
data['dev_context_spans']
))
    dev_y = dev_orig['answers'].tolist()[:len(dev)]
    # answers are stored as stringified Python lists; ast.literal_eval would be a safer parser
    dev_y = [eval(y) for y in dev_y]
return train, dev, dev_y, embedding, opt
def load_json(cls, f):
"""
Loads model file saved as JSON file from text stream ``f``.
"""
m = cls()
record = json.load(f)
# Load header
if 'header' not in record:
raise InvalidModelFormatError('header section does not exist')
m.header.set(record['header'])
# Load system_data
if 'system' not in record:
raise InvalidModelFormatError('system section does not exist')
m.system.set(record['system'])
# Load user_data
if 'user_raw' in record:
if 'user' in record:
printe('Notice: using "user_raw" record from JSON; "user" record is ignored')
raw = base64.b64decode(record['user_raw'])
try:
m.user = cls.UserContainer.loads(raw)
except UnicodeDecodeError:
            printe('Warning: model contains non-UTF-8 strings; cannot be loaded')
m.user = cls.UserContainer()
m.user.user_data = None
m._user_raw = raw
elif 'user' in record:
m.user.set(record['user'])
else:
raise InvalidModelFormatError('user or user_raw section does not exist')
return m
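# The 'user_raw' record keeps the raw user_data bytes base64-encoded so that
# arbitrary binary payloads survive the JSON round trip; a standalone
# illustration of that encoding step (the payload bytes are made up):
import base64

raw = b'\x00\xff\xfe binary payload'
encoded = base64.b64encode(raw).decode('ascii')   # JSON-safe string
assert base64.b64decode(encoded) == raw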
def load(cls, f, *args, **kwargs):
    # Must be implemented in subclasses.
raise NotImplementedError
def loads(cls, data, *args, **kwargs):
return cls.load(BytesIO(data), *args, **kwargs)
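# loads() simply adapts in-memory bytes to the stream-based load(); a minimal
# subclass sketch under that contract (Container and its payload handling
# are assumptions, not the original class):
from io import BytesIO

class Container(object):
    @classmethod
    def load(cls, f, *args, **kwargs):
        obj = cls()
        obj.payload = f.read()
        return obj

    @classmethod
    def loads(cls, data, *args, **kwargs):
        return cls.load(BytesIO(data), *args, **kwargs)

assert Container.loads(b'abc').payload == b'abc'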
def load_data(folder=data_folder):
opt = {}
    with open(folder + 'meta.msgpack', 'rb') as f:
        meta = msgpack.load(f, encoding='utf8')
    embedding = meta['embedding']
    opt['pretrained_words'] = True
    opt['vocab_size'] = len(embedding)
    opt['embedding_dim'] = len(embedding[0])
    with open(folder + 'data.msgpack', 'rb') as f:
        data = msgpack.load(f, encoding='utf8')
    # detect the CSV encoding once (from dev.csv) and reuse it for both files
    with open(folder + 'dev.csv', 'rb') as f:
        charResult = chardet.detect(f.read())
    train_orig = pd.read_csv(folder + 'train.csv', encoding=charResult['encoding'])
    dev_orig = pd.read_csv(folder + 'dev.csv', encoding=charResult['encoding'])
    train = list(zip(
        data['trn_context_ids'],
        data['trn_context_features'],
        data['trn_context_tags'],
        data['trn_context_ents'],
        data['trn_question_ids'],
        train_orig['answer_start_token'].tolist(),
        train_orig['answer_end_token'].tolist(),
        data['trn_context_text'],
        data['trn_context_spans']
    ))
    dev = list(zip(
        data['dev_context_ids'],
        data['dev_context_features'],
        data['dev_context_tags'],
        data['dev_context_ents'],
        data['dev_question_ids'],
        data['dev_context_text'],
        data['dev_context_spans']
    ))
dev_y = dev_orig['answers'].tolist()[:len(dev)]
dev_y = [eval(y) for y in dev_y]
# discover lengths
opt['context_len'] = get_max_len(data['trn_context_ids'], data['dev_context_ids'])
opt['feature_len'] = get_max_len(data['trn_context_features'][0], data['dev_context_features'][0])
opt['question_len'] = get_max_len(data['trn_question_ids'], data['dev_question_ids'])
    print(train_orig['answer_start_token'].tolist()[:10])  # debug: peek at the first answer-start tokens
return train, dev, dev_y, embedding, opt
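# get_max_len is not defined in these snippets; a plausible minimal version,
# assuming it returns the longest sequence length across the given lists of
# sequences:
def get_max_len(*sequence_lists):
    return max(len(seq) for seqs in sequence_lists for seq in seqs)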
def load_data(opt):
    with open(opt['squad_dir'] + 'meta.msgpack', 'rb') as f:
meta = msgpack.load(f, encoding='utf8')
embedding = meta['embedding']
opt['pretrained_words'] = True
opt['vocab_size'] = len(embedding)
opt['embedding_dim'] = len(embedding[0])
with open(args.data_file, 'rb') as f:
data = msgpack.load(f, encoding='utf8')
    # with open(opt['squad_dir'] + 'train.csv', 'rb') as f:
    #     charResult = chardet.detect(f.read())
    train_orig = pd.read_csv(opt['squad_dir'] + 'train.csv')  # , encoding=charResult['encoding']
    dev_orig = pd.read_csv(opt['squad_dir'] + 'dev.csv')  # , encoding=charResult['encoding']
    train = list(zip(
        data['trn_context_ids'],
        data['trn_context_features'],
        data['trn_context_tags'],
        data['trn_context_ents'],
        data['trn_question_ids'],
        train_orig['answer_start_token'].tolist(),
        train_orig['answer_end_token'].tolist(),
        data['trn_context_text'],
        data['trn_context_spans']
    ))
    dev = list(zip(
        data['dev_context_ids'],
        data['dev_context_features'],
        data['dev_context_tags'],
        data['dev_context_ents'],
        data['dev_question_ids'],
        data['dev_context_text'],
        data['dev_context_spans']
    ))
dev_y = dev_orig['answers'].tolist()[:len(dev)]
dev_y = [eval(y) for y in dev_y]
# discover lengths
opt['context_len'] = get_max_len(data['trn_context_ids'], data['dev_context_ids'])
opt['feature_len'] = get_max_len(data['trn_context_features'][0], data['dev_context_features'][0])
opt['question_len'] = get_max_len(data['trn_question_ids'], data['dev_question_ids'])
print(train_orig['answer_start_token'].tolist()[:10])
return train, dev, dev_y, embedding, opt
def load_binary(cls, f, validate=True):
"""
Loads Jubatus binary model file from binary stream ``f``.
When ``validate`` is ``True``, the model file format is strictly validated.
"""
m = cls()
checksum = 0
# Load header
h = cls.Header.load(f)
m.header = h
if validate:
checksum = crc32(h.dumps(False), checksum)
# Load system_data
buf = f.read(h.system_data_size)
m.system = cls.SystemContainer.loads(buf)
if validate:
if h.system_data_size != len(buf):
raise InvalidModelFormatError(
'EOF detected while reading system_data: ' +
'expected {0} bytes, got {1} bytes'.format(h.system_data_size, len(buf)))
checksum = crc32(buf, checksum)
# Load user_data
buf = f.read(h.user_data_size)
m.user = cls.UserContainer.loads(buf)
m._user_raw = buf
if validate:
if h.user_data_size != len(buf):
raise InvalidModelFormatError(
'EOF detected while reading user_data: ' +
'expected {0} bytes, got {1} bytes'.format(h.user_data_size, len(buf)))
checksum = crc32(buf, checksum)
if validate:
# Convert the checksum into 32-bit unsigned integer (for Python 2/3 compatibility)
checksum = checksum & 0xffffffff
# Check CRC
        if checksum != h.crc32:
            raise InvalidModelFormatError(
                'CRC32 mismatch: expected {0}, got {1}'.format(h.crc32, checksum))
return m
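# Why the `& 0xffffffff` mask above: on Python 2, zlib.crc32 could return a
# signed 32-bit integer, while the header stores an unsigned value; masking
# normalizes both. A small illustration of the incremental checksum over two
# made-up sections:
from zlib import crc32

checksum = crc32(b'system-section', 0)
checksum = crc32(b'user-section', checksum)
checksum &= 0xffffffff               # force an unsigned 32-bit result
assert 0 <= checksum <= 0xffffffff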