def table_iter(pairs, # type: List[BytesPair]
uniqueidx, # type: int
enc='utf-8', # type: str
fromdesc='', # type: str
todesc='', # type: str
numlines=2, # type: int
wrapcolumn=0 # type: int
):
# type: (...) -> Iterator[Tuple[str, str, str]]
htmldiffer = HtmlMultiDiff(tabsize=8, wrapcolumn=wrapcolumn)
htmldiffer.uniqueidx = uniqueidx
table = htmldiffer.table_from_pairs(pairs,
enc,
fromdesc=fromdesc,
todesc=todesc,
context=True,
numlines=numlines)
for tablestart, tbody, tableend in iter_tbodies(table):
yield tablestart, tbody, tableend
def sentence_iterator(self, file_path: str) -> Iterator[OntonotesSentence]:
"""
An iterator over the sentences in an individual CONLL formatted file.
"""
with codecs.open(file_path, 'r', encoding='utf8') as open_file:
conll_rows = []
for line in open_file:
line = line.strip()
if line != '' and not line.startswith('#'):
conll_rows.append(line)
else:
if not conll_rows:
continue
else:
yield self._conll_rows_to_sentence(conll_rows)
conll_rows = []
def read_android(self) -> Iterator[DeviceConfig]:
"""Read Android-specific database file."""
_LOGGER.info("Reading tokens from Android DB")
c = self.conn.execute("SELECT * FROM devicerecord WHERE token IS NOT '';")
for dev in c.fetchall():
if self.dump_raw:
BackupDatabaseReader.dump_raw(dev)
ip = dev['localIP']
mac = dev['mac']
model = dev['model']
name = dev['name']
token = dev['token']
config = DeviceConfig(name=name, ip=ip, mac=mac,
model=model, token=token)
yield config
def read_tokens(self, db) -> Iterator[DeviceConfig]:
"""Read device information out from a given database file.
:param str db: Database file"""
self.db = db
_LOGGER.info("Reading database from %s" % db)
self.conn = sqlite3.connect(db)
self.conn.row_factory = sqlite3.Row
with self.conn:
is_android = self.conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='devicerecord';").fetchone() is not None
is_apple = self.conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='ZDEVICE'").fetchone() is not None
if is_android:
yield from self.read_android()
elif is_apple:
yield from self.read_apple()
else:
_LOGGER.error("Error, unknown database type!")
def shuffle_and_batch(items: List[T], batch_size: int,
rng: Optional[random.Random] = None) \
-> Iterator[List[T]]:
"""Optionally shuffles and batches items in a list.
Args:
- items: List of items to shuffle & batch.
- batch_size: size of batches.
        - rng: random number generator if items should be shuffled, else None.
Returns: Batch iterator
"""
todo = list(range(len(items)))
if rng is not None:
rng.shuffle(todo)
while todo:
indices = todo[:batch_size]
todo = todo[batch_size:]
items_batch = [items[i] for i in indices]
yield items_batch
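# Usage sketch for shuffle_and_batch above: a seeded random.Random gives a
# reproducible shuffle, and passing rng=None keeps the original order. The
# variable names are illustrative only.
import random

items = list(range(10))
batches = list(shuffle_and_batch(items, batch_size=3, rng=random.Random(42)))
# Every element appears exactly once across batches; the last batch may be short.
assert sorted(i for batch in batches for i in batch) == items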
def iter_grid(
cls,
min_pos: 'Vec',
max_pos: 'Vec',
stride: int=1,
) -> Iterator['Vec']:
"""Loop over points in a bounding box. All coordinates should be integers.
Both borders will be included.
"""
min_x, min_y, min_z = map(int, min_pos)
max_x, max_y, max_z = map(int, max_pos)
for x in range(min_x, max_x + 1, stride):
for y in range(min_y, max_y + 1, stride):
for z in range(min_z, max_z + 1, stride):
yield cls(x, y, z)
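# A standalone sketch of the same inclusive bounding-box iteration as
# iter_grid above, yielding plain (x, y, z) tuples instead of Vec instances
# (the Vec class itself is not shown here).
from typing import Iterator, Tuple

def iter_grid_tuples(min_pos: Tuple[int, int, int],
                     max_pos: Tuple[int, int, int],
                     stride: int = 1) -> Iterator[Tuple[int, int, int]]:
    """Yield every integer point between min_pos and max_pos, borders included."""
    min_x, min_y, min_z = min_pos
    max_x, max_y, max_z = max_pos
    for x in range(min_x, max_x + 1, stride):
        for y in range(min_y, max_y + 1, stride):
            for z in range(min_z, max_z + 1, stride):
                yield (x, y, z)

# e.g. list(iter_grid_tuples((0, 0, 0), (1, 1, 1))) contains all 8 corner points.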
def elf_file_filter(paths: Iterator[str]) -> Iterator[Tuple[str, ELFFile]]:
"""Filter through an iterator of filenames and load up only ELF
files
"""
for path in paths:
if path.endswith('.py'):
continue
else:
try:
with open(path, 'rb') as f:
candidate = ELFFile(f)
yield path, candidate
except ELFError:
# not an elf file
continue
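# Usage sketch for elf_file_filter above. It assumes pyelftools is installed
# (that is where ELFFile and ELFError come from) and that `build_dir` points
# at a directory of files to inspect -- the path is illustrative only.
import os

build_dir = "dist"  # illustrative path
candidate_paths = (os.path.join(root, name)
                   for root, _dirs, names in os.walk(build_dir)
                   for name in names)
for elf_path, elf in elf_file_filter(candidate_paths):
    print(elf_path, elf.header["e_machine"])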
def items(self, humanise: bool=True, precision: int=2) -> Iterator[dict]:
"""Returns an iterator for scanned items list. It doesn't return the
internal _items list because we don't want it to be modified outside.
Blocks until the scanning operation has been completed on first access.
:param humanise: Humanise flag to format results (defaults to True)
:type humanise: bool
:param precision: The floating precision of the human-readable size format (defaults to 2).
:type precision: int
:return: Iterator for the internal _items list.
:rtype: iterator
"""
self._await()
# Don't humanise
if humanise is False:
return iter(self._items)
# Humanise
humanise_item = partial(self._humanise_item, precision=precision)
return map(humanise_item, self._items)
def read_texts(path: str, source_type: FileType) -> Iterator[str]:
"""
    Read texts from the given source.
    :param path: path to a file or directory.
    :param source_type: type of the source files.
"""
paths = Reader.get_paths(path, source_type.value)
for filename in paths:
with open(filename, "r", encoding="utf-8") as file:
if source_type == FileType.XML:
for elem in Reader.__xml_iter(file, 'item'):
yield elem.find(".//text").text
elif source_type == FileType.JSON:
                # TODO
j = json.load(file)
for item in j['items']:
yield item['text']
elif source_type == FileType.RAW:
text = file.read()
for t in text.split(RAW_SEPARATOR):
yield t
def get_paths(path: str, ext: str) -> Iterator[str]:
"""
    Collect the paths of all files with the given extension under the given path.
    :param path: path to a file or directory.
    :param ext: required file extension (including the leading dot).
"""
if os.path.isfile(path):
if ext == os.path.splitext(path)[1]:
yield path
else:
for root, folders, files in os.walk(path):
for file in files:
if ext == os.path.splitext(file)[1]:
yield os.path.join(root, file)
            # Note: os.walk already descends into subdirectories, so no extra
            # recursion over `folders` is needed here.
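# Usage sketch for Reader.get_paths above (the "corpus" directory is
# illustrative). The extension is compared against os.path.splitext(...)[1],
# so it must include the leading dot.
txt_files = list(Reader.get_paths("corpus", ".txt"))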
def __init__(self,
path: str,
vocab: Optional[Dict[str, int]],
add_bos: bool = False,
limit: Optional[int] = None) -> None:
self.path = path
self.vocab = vocab
self.bos_id = None
if vocab is not None:
assert C.UNK_SYMBOL in vocab
assert vocab[C.PAD_SYMBOL] == C.PAD_ID
assert C.BOS_SYMBOL in vocab
assert C.EOS_SYMBOL in vocab
self.bos_id = vocab[C.BOS_SYMBOL]
else:
check_condition(not add_bos, "Adding a BOS symbol requires a vocabulary")
self.add_bos = add_bos
self.limit = limit
self._iter = None # type: Optional[Iterator]
self._iterated_once = False
self.count = 0
self._next = None
def get_torch_num_workers(num_workers: int) -> int:
    """Turn an int into a useful number of workers for a pytorch DataLoader.
    -1 means "use all CPUs", -2 means "use all but 1 CPU", etc.
    Note: 0 is interpreted by pytorch as doing data loading in the main process, while any positive number spawns
    that many worker processes. We do not allow more workers than there are CPUs."""
num_cpu = cpu_count()
if num_workers < 0:
n_workers = num_cpu + 1 + num_workers
if n_workers < 0:
print("Warning: {} fewer workers than the number of CPU's were specified, but there are only {} CPU's; "
"running data loading in the main process (num_workers = 0).".format(num_workers + 1, num_cpu))
num_workers = max(0, n_workers)
if num_workers > num_cpu:
print("Warning, `num_workers` is {} but only {} CPU's are available; "
"using this number instead".format(num_workers, num_cpu))
return min(num_workers, num_cpu)
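# Usage sketch with PyTorch's DataLoader: get_torch_num_workers(-2) asks for
# "all but one CPU". Assumes the function above gets cpu_count from
# multiprocessing and that torch is installed; the toy TensorDataset is
# illustrative only.
import torch
from torch.utils.data import DataLoader, TensorDataset

toy_dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))
loader = DataLoader(toy_dataset,
                    batch_size=16,
                    num_workers=get_torch_num_workers(-2))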
#####################################################################
# Iterator utils #
#####################################################################
def __iter__(self) -> Iterator[str]:
'''Return an iterator over the lines written to stdout. May only be called once! Might raise a SolverSubprocessError.'''
assert not self.iterating, 'You may only iterate once over a single DlvhexLineReader instance.'
self.iterating = True
# Requirement: dlvhex2 needs to flush stdout after every line
with io.TextIOWrapper(self.process.stdout, encoding=self.stdout_encoding) as stdout_lines:
for line in stdout_lines:
yield line
# Tell dlvhex2 to prepare the next answer set
if not self.process.stdin.closed:
self.process.stdin.write(b'\n')
self.process.stdin.flush()
else:
break
# We've exhausted stdout, so either:
# 1. we got all answer sets, or
# 2. an error occurred,
# and dlvhex closed stdout (and probably terminated).
# Give it a chance to terminate gracefully.
try:
self.process.wait(timeout=0.005) # type: ignore (mypy does not know about `timeout`)
except subprocess.TimeoutExpired: # type: ignore (mypy does not know about `TimeoutExpired`)
pass
self.close()
def get_profile_posts(self, profile_metadata: Dict[str, Any]) -> Iterator[Post]:
"""Retrieve all posts from a profile."""
profile_name = profile_metadata['user']['username']
profile_id = int(profile_metadata['user']['id'])
yield from (Post(self, node, profile=profile_name, profile_id=profile_id)
for node in profile_metadata['user']['media']['nodes'])
has_next_page = profile_metadata['user']['media']['page_info']['has_next_page']
end_cursor = profile_metadata['user']['media']['page_info']['end_cursor']
while has_next_page:
# We do not use self.graphql_node_list() here, because profile_metadata
# lets us obtain the first 12 nodes 'for free'
data = self.graphql_query(17888483320059182, {'id': profile_metadata['user']['id'],
'first': 200,
'after': end_cursor},
'https://www.instagram.com/{0}/'.format(profile_name))
media = data['data']['user']['edge_owner_to_timeline_media']
yield from (Post(self, edge['node'], profile=profile_name, profile_id=profile_id)
for edge in media['edges'])
has_next_page = media['page_info']['has_next_page']
end_cursor = media['page_info']['end_cursor']
def dcos_cluster(
oss_artifact: Path,
cluster_backend: ClusterBackend,
) -> Iterator[Cluster]:
"""
Return a `Cluster`.
This is module scoped as we do not intend to modify the cluster in ways
that make tests interfere with one another.
"""
with Cluster(
cluster_backend=cluster_backend,
masters=1,
agents=0,
public_agents=0,
) as cluster:
cluster.install_dcos_from_path(oss_artifact, log_output_live=True)
yield cluster
def _latest_version_from_object_names(object_names: typing.Iterator[str]) -> str:
dead_versions = set() # type: typing.Set[str]
all_versions = set() # type: typing.Set[str]
set_checks = [
(DSS_BUNDLE_TOMBSTONE_REGEX, dead_versions),
(DSS_BUNDLE_KEY_REGEX, all_versions),
]
for object_name in object_names:
for regex, version_set in set_checks:
match = regex.match(object_name)
if match:
_, version = match.groups()
version_set.add(version)
break
version = None
for current_version in (all_versions - dead_versions):
if version is None or current_version > version:
version = current_version
return version
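# The selection loop at the end of _latest_version_from_object_names picks the
# lexicographically largest surviving version string; with sets like the toy
# ones below, it behaves the same as max() with a default of None.
all_versions = {"2017-01-01T000000.000000Z", "2018-06-01T000000.000000Z"}
dead_versions = {"2017-01-01T000000.000000Z"}
latest = max(all_versions - dead_versions, default=None)
assert latest == "2018-06-01T000000.000000Z"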
def train(dataset: DataSet, n_iter: int = 3000, batch_size: int = 25) -> Iterator[AutoEncoder]:
n = dataset.size
input_dimension = dataset.input.shape[1]
hidden_dimension = 2
model = AutoEncoder(input_dimension, hidden_dimension)
optimizer = optimizers.Adam()
optimizer.setup(model)
for j in range(n_iter):
shuffled = np.random.permutation(n)
for i in range(0, n, batch_size):
indices = shuffled[i:i+batch_size]
x = Variable(dataset.input[indices])
model.cleargrads()
loss = model(x)
loss.backward()
optimizer.update()
yield model
def nest_annotations(annotations: Iterator[Annotation],
text_length: int) -> List[NestableAnnotation]:
"""Converts overlapping annotations into a nested version."""
in_order = sorted(annotations, key=lambda a: (a.start, -a.end))
# Easier to operate on a single root, even if we'll remove it later.
root = last = NestableAnnotation(PlainText(start=0, end=text_length), None)
for anote in in_order:
# We're not allowing non-nested overlapping annotations, so we won't
# compare ends when determining nesting
while anote not in last:
last = last.parent
# Enforce all annotations to be nested rather than overlapping
anote.end = min(anote.end, last.end)
last = NestableAnnotation(anote, last)
root.wrap_unwrapped()
return root.children
def build_lca_map(gen: typing.Iterator, tree: Taxonomy) -> dict:
"""
Build a last common ancestor dictionary
    :param gen: iterator of (query name, reference name) pairs
    :param tree: taxonomy tree (e.g. NCBITree) mapping a reference name to its taxon
    :return: dict mapping query name (str) to ncbi_tid (int)
"""
lca_map = {}
for qname, rname in gen:
tax = tree(rname)
if qname in lca_map:
current_tax = lca_map[qname]
if current_tax:
if current_tax != tax:
lca_map[qname] = least_common_ancestor((tax, current_tax))
else:
lca_map[qname] = tax
return lca_map
def predict_from_dataset(self, dataset: Dataset,
show_eos: bool=True,
use_queue: bool=True,
**kwargs) -> Iterator[Tuple[str, str, str]]:
if use_queue:
evaluator = self._predict_from_dataset_queue(dataset, **kwargs)
else:
evaluator = self._predict_from_dataset_feed(dataset, **kwargs)
for source, target, translation in evaluator:
# unpack and decode result
yield from zip(
self.dataset.decode_as_batch(source,
show_eos=show_eos),
self.dataset.decode_as_batch(target,
show_eos=show_eos),
self.dataset.decode_as_batch(translation,
show_eos=show_eos)
)
def __iter__(self) -> Iterator[Tuple[str, str]]:
with NLTKEnv() as nltk_env:
nltk_env.download('perluniprops')
nltk_env.download('comtrans')
from nltk.corpus import comtrans
from nltk.tokenize.moses import MosesDetokenizer
als = comtrans.aligned_sents(self._comtrans_string())
source_detokenizer = MosesDetokenizer(lang=self._source_lang)
target_detokenizer = MosesDetokenizer(lang=self._target_lang)
for source, target in self._comtrans_maybe_swap(als):
source = source_detokenizer.detokenize(source, return_str=True)
target = target_detokenizer.detokenize(target, return_str=True)
if self._length_checker(source, target):
yield (source, target)
def __iter__(self) -> Iterator[Tuple[str, str]]:
with EuroparlCache() as europarl_cache:
europarl_cache.download(name='europarl-v7.tgz', url=_v7_url)
        # peek into the tarball
filepath = europarl_cache.filepath('europarl-v7.tgz')
source_filepath, target_filepath = self._files()
with tar_extract_file(filepath, source_filepath) as source_file, \
tar_extract_file(filepath, target_filepath) as target_file:
observations = 0
for source, target in zip(source_file, target_file):
source, target = (source.rstrip(), target.rstrip())
if self._length_checker(source, target):
yield (source, target)
observations += 1
if self._max_observations is not None and \
observations >= self._max_observations:
break
def _build_dataset(self) -> Iterator[Tuple[str, str]]:
length_type = size_to_unsigned_type(self._max_length)
for _ in range(self._examples):
length = self._random.randint(
self._min_length, self._max_length + 1,
dtype=length_type
)
target = self._random.randint(
0, self._digits,
size=length, dtype=np.int8
)
source = text_map[target]
        target_str = ''.join(target.astype(str))
source_str = ' '.join(source)
yield (source_str, target_str)
def ca_develop(network: FeedForwardNetwork) -> Iterator[ToroidalCellGrid2D]:
def transition_f(inputs_discrete_values: Sequence[CELL_STATE_T]) -> CELL_STATE_T:
neighbour_values, xy_values = inputs_discrete_values[:-2], inputs_discrete_values[-2:]
if all((x == initial_grid.dead_cell) for x in neighbour_values):
return initial_grid.dead_cell
inputs_float_values = tuple(state_normalization_rules[n] for n in neighbour_values) + \
tuple(coord_normalization_rules[n] for n in xy_values)
outputs = network.serial_activate(inputs_float_values)
return max(zip(alphabet, outputs), key=itemgetter(1))[0]
yield initial_grid
for grid in iterate_ca_n_times_or_until_cycle_found(
initial_grid=initial_grid,
transition_f=transition_f,
n=iterations,
iterate_f=iterate_ca_once_with_coord_inputs
):
yield grid
def create_initial_population(neat_config: CPPNNEATConfig) -> Iterator[Genome]:
for _ in range(neat_config.pop_size):
g_id = uuid4().int
g = neat_config.genotype.create_unconnected(g_id, neat_config)
hidden_nodes = neat_config.initial_hidden_nodes
if hidden_nodes:
g.add_hidden_nodes(hidden_nodes)
if neat_config.initial_connection == 'fs_neat':
g.connect_fs_neat()
elif neat_config.initial_connection == 'fully_connected':
g.connect_full()
elif neat_config.initial_connection == 'partial':
if callable(neat_config.connection_fraction):
fraction = neat_config.connection_fraction()
else:
fraction = neat_config.connection_fraction
g.connect_partial(fraction)
yield g
def sigma_scaled(population: List[Genome], **kwargs) -> Iterator[PAIR_T]:
    if len(population) <= 1:
        raise TooFewIndividuals
fitnesses = tuple(x.fitness for x in population)
    if not any(f > 0.0 for f in fitnesses):
        return random_choice(population)
sigma = stdev(fitnesses)
average_fitness = mean(fitnesses)
expected_value_func = lambda x: 1 if sigma == 0 else 1 + ((x - average_fitness) / (2 * sigma))
sigma_sum = sum(expected_value_func(x) for x in fitnesses)
scaling_func = lambda x: expected_value_func(x) / sigma_sum
return roulette(population=population, scaling_func=scaling_func, **kwargs)
def tournament(population: List[Genome], group_size: int, epsilon: float, **kwargs) -> Iterator[PAIR_T]:
def get_one(group: Sequence[Genome]) -> Genome:
r = random()
if r < epsilon:
return choice(group)
return max(group, key=attrgetter('fitness'))
    if len(population) <= 1:
        raise TooFewIndividuals
while True:
pool = list(population) # make a shallow copy
group_a = sample(pool, group_size)
a = get_one(group_a)
pool.remove(a)
group_b = sample(pool, group_size)
b = get_one(group_b)
yield (a, b)
def find_pattern_partial_matches(grid: CellGrid2D, pattern: PATTERN_T) -> Iterator[float]:
(x_min, y_min), (x_max, y_max) = grid.get_extreme_coords(pad=1)
pattern_h, pattern_w = len(pattern), len(pattern[0])
pattern_area = pattern_h * pattern_w
for y in range(y_min, y_max):
for x in range(x_min, x_max):
rectangle = grid.get_rectangle(
x_range=(x, x + pattern_w),
y_range=(y, y + pattern_h),
)
            if all(all(cell == grid.dead_cell for cell in row) for row in rectangle):
                # An entirely dead window scores 0; skip the cell-by-cell comparison.
                yield 0.0
                continue
correct_count = count_correct_cells(test_pattern=rectangle, target_pattern=pattern)
yield (correct_count / pattern_area)
def _match_sequence_variable(self, wildcard: Wildcard, transition: _Transition) -> Iterator[_State]:
min_count = wildcard.min_count
if len(self.subjects) < min_count:
return
matched_subject = []
for _ in range(min_count):
matched_subject.append(self.subjects.popleft())
while True:
if self.associative[-1] and wildcard.fixed_size:
assert min_count == 1, "Fixed wildcards with length != 1 are not supported."
if len(matched_subject) > 1:
wrapped = self.associative[-1](*matched_subject)
else:
wrapped = matched_subject[0]
else:
if len(matched_subject) == 0 and wildcard.optional is not None:
wrapped = wildcard.optional
else:
wrapped = tuple(matched_subject)
yield from self._check_transition(transition, wrapped, False)
if not self.subjects:
break
matched_subject.append(self.subjects.popleft())
self.subjects.extendleft(reversed(matched_subject))
def _match_with_bipartite(
self,
subject_ids: MultisetOfInt,
pattern_set: MultisetOfInt,
substitution: Substitution,
) -> Iterator[Tuple[Substitution, MultisetOfInt]]:
bipartite = self._build_bipartite(subject_ids, pattern_set)
for matching in enum_maximum_matchings_iter(bipartite):
if len(matching) < len(pattern_set):
break
if not self._is_canonical_matching(matching):
continue
for substs in itertools.product(*(bipartite[edge] for edge in matching.items())):
try:
bipartite_substitution = substitution.union(*substs)
except ValueError:
continue
matched_subjects = Multiset(subexpression for subexpression, _ in matching)
yield bipartite_substitution, matched_subjects