def __init__(self,
             label: Union[str, int],
             label_namespace: str = 'labels',
             skip_indexing: bool = False) -> None:
    """
    Store a label for this field.

    With ``skip_indexing=False`` (the default) the label must be a string
    that will be indexed later; with ``skip_indexing=True`` the caller
    promises the label is already an integer id, which is stored directly.
    """
    self.label = label
    self._label_namespace = label_namespace
    self._label_id = None
    self._maybe_warn_for_namespace(label_namespace)
    if not skip_indexing:
        # A string label is required so it can be indexed later.
        if not isinstance(label, str):
            raise ConfigurationError("LabelFields must be passed a string label if skip_indexing=False. "
                                     "Found label: {} with type: {}.".format(label, type(label)))
        return
    # skip_indexing=True: the label must already be an integer id.
    if isinstance(label, int):
        self._label_id = label
    else:
        raise ConfigurationError("In order to skip indexing, your labels must be integers. "
                                 "Found label = {}".format(label))
Example source snippets using Python's `typing.Union`
def __init__(self,
             labels: Union[List[str], List[int]],
             sequence_field: SequenceField,
             label_namespace: str = 'labels') -> None:
    """
    Store one label per element of ``sequence_field``.

    Labels must be all strings (indexed later) or all ints (treated as
    already indexed); a length mismatch with the sequence is an error.
    """
    self.labels = labels
    self.sequence_field = sequence_field
    self._label_namespace = label_namespace
    self._indexed_labels = None
    self._maybe_warn_for_namespace(label_namespace)
    if sequence_field.sequence_length() != len(labels):
        raise ConfigurationError("Label length and sequence length "
                                 "don't match: %d and %d" % (len(labels), sequence_field.sequence_length()))
    if all(isinstance(x, int) for x in labels):
        # Integer labels are assumed to be pre-indexed ids.
        self._indexed_labels = labels
    elif not all(isinstance(x, str) for x in labels):
        raise ConfigurationError("SequenceLabelFields must be passed either all "
                                 "strings or all ints. Found labels {} with "
                                 "types: {}.".format(labels, [type(x) for x in labels]))
def from_dataset(cls,
                 dataset,
                 min_count: int = 1,
                 max_vocab_size: Union[int, Dict[str, int]] = None,
                 non_padded_namespaces: Sequence[str] = DEFAULT_NON_PADDED_NAMESPACES,
                 pretrained_files: Optional[Dict[str, str]] = None,
                 only_include_pretrained_words: bool = False) -> 'Vocabulary':
    """
    Constructs a vocabulary given a :class:`.Dataset` and some parameters. We count all of the
    vocabulary items in the dataset, then pass those counts, and the other parameters, to
    :func:`__init__`. See that method for a description of what the other parameters do.
    """
    logger.info("Fitting token dictionary from dataset.")
    # namespace -> token -> occurrence count
    token_counts: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
    for instance in tqdm.tqdm(dataset.instances):
        instance.count_vocab_items(token_counts)
    return Vocabulary(
        counter=token_counts,
        min_count=min_count,
        max_vocab_size=max_vocab_size,
        non_padded_namespaces=non_padded_namespaces,
        pretrained_files=pretrained_files,
        only_include_pretrained_words=only_include_pretrained_words,
    )
def test_site(url: str, previous_results: dict, remote_host: str = None) -> Dict[str, Dict[str, Union[str, bytes]]]:
# test first mx
try:
hostname = previous_results['mx_records'][0][1]
except (KeyError, IndexError):
return {
'jsonresult': {
'mime_type': 'application/json',
'data': b'',
},
}
jsonresult = run_testssl(hostname, True, remote_host)
return {
'jsonresult': {
'mime_type': 'application/json',
'data': jsonresult,
},
}
def _subscribe_to_topic(self, alias: str, topic: Union[bytes, str]):
    '''
    Perform the raw ZeroMQ subscription of the socket identified by
    `alias` to a specific topic. Only meaningful for SUB/SYNC_SUB
    sockets; the handler is not registered here.
    '''
    topic_bytes = topic_to_bytes(topic)
    address = self.address[alias]
    if isinstance(address, AgentAddress):
        self.socket[alias].setsockopt(zmq.SUBSCRIBE, topic_bytes)
    elif isinstance(address, AgentChannel):
        # SYNC_SUB channels prefix every topic with the channel UUID.
        self.socket[address.receiver].setsockopt(zmq.SUBSCRIBE,
                                                 address.uuid + topic_bytes)
    else:
        raise NotImplementedError('Unsupported address type %s!' %
                                  address)
def topics_to_bytes(handlers: Dict[Union[bytes, str], Any], uuid: bytes = b''):
    '''
    Normalize a topic->handler mapping so it is ready for the actual
    ZeroMQ subscription calls: every topic is converted to bytes and
    prefixed with the channel uuid.
    Parameters
    ----------
    handlers
        Contains pairs "topic - handler".
    uuid
        uuid of the SYNC_PUB/SYNC_SUB channel (if applies). For normal
        PUB/SUB communication, this should be `b''`.
    Returns
    -------
    Dict[bytes, Any]
    '''
    return {uuid + topic_to_bytes(topic): handler
            for topic, handler in handlers.items()}
def _parse_simple_yaml(buf: bytes) -> Stats:
data = buf.decode('ascii')
assert data[:4] == '---\n'
data = data[4:] # strip YAML head
stats = {}
for line in data.splitlines():
key, value = line.split(': ', 1)
try:
v = int(value) # type: Union[int, str]
except ValueError:
v = value
stats[key] = v
return stats
def reverse_session_view(request: HttpRequest, pk: int) -> Union[HttpRequest, HttpResponseRedirect]:
    """Confirm (GET) or execute (POST) the reversal of a cashdesk session.

    A failed reversal surfaces the FlowError as an error message; in
    both cases the POST redirects back to the session detail page.
    """
    session = get_object_or_404(CashdeskSession, pk=pk)
    if request.method == 'GET':
        return render(request, 'backoffice/reverse_session.html', {
            'session': session,
        })
    if request.method == 'POST':
        try:
            reverse_session(session)
        except FlowError as e:
            messages.error(request, str(e))
        else:
            messages.success(request, _('All transactions in the session have been cancelled.'))
        return redirect('backoffice:session-detail', pk=pk)
def rep1sep(parser: Union[Parser, Sequence[Input]], separator: Union[Parser, Sequence[Input]]) \
        -> RepeatedOnceSeparatedParser:
    """Match a parser one or more times separated by another parser.
    This matches repeated sequences of ``parser`` separated by ``separator``.
    If there is at least one match, a list containing the values of the
    ``parser`` matches is returned. The values from ``separator`` are discarded.
    If it does not match ``parser`` at all, it fails.
    Args:
        parser: Parser or literal
        separator: Parser or literal
    """
    # Bare string literals are promoted to literal parsers.
    element = lit(parser) if isinstance(parser, str) else parser
    sep = lit(separator) if isinstance(separator, str) else separator
    return RepeatedOnceSeparatedParser(element, sep)
def repsep(parser: Union[Parser, Sequence[Input]], separator: Union[Parser, Sequence[Input]]) \
        -> RepeatedSeparatedParser:
    """Match a parser zero or more times separated by another parser.
    This matches repeated sequences of ``parser`` separated by ``separator``. A
    list is returned containing the value from each match of ``parser``. The
    values from ``separator`` are discarded. If there are no matches, an empty
    list is returned.
    Args:
        parser: Parser or literal
        separator: Parser or literal
    """
    # Bare string literals are promoted to literal parsers.
    element = lit(parser) if isinstance(parser, str) else parser
    sep = lit(separator) if isinstance(separator, str) else separator
    return RepeatedSeparatedParser(element, sep)
def select_channel(
        self,
        versions: typing.Set[CustomVersion],
        update_channel: str = channel.STABLE
) -> typing.Union[CustomVersion, None]:
    """
    Selects the latest version, equals or higher than "channel"
    Args:
        versions: versions to select from
        update_channel: member of :class:`Channel`
    Returns: latest version or None
    """
    # Bug fix: log the channel actually used for filtering; the original
    # interpolated `str(channel)`, i.e. the imported `channel` object,
    # not the `update_channel` argument.
    LOGGER.debug(f'selecting latest version amongst {len(versions)}; active channel: {update_channel}')
    options = list(self.filter_channel(versions, update_channel))
    if options:
        # max() relies on CustomVersion ordering to pick the latest.
        return max(options)
    LOGGER.debug('no version passed the test')
    return None
def load_yaml(fname: str) -> Union[List, Dict]:
    """Load a YAML file.

    Parse errors and undecodable files are logged and re-raised as
    ScarlettError; empty/falsy documents are normalized to ``{}``.
    """
    try:
        with open(fname, encoding='utf-8') as conf_file:
            parsed = yaml.load(conf_file, Loader=SafeLineLoader)
    except yaml.YAMLError as exc:
        logger.error(exc)
        raise ScarlettError(exc)
    except UnicodeDecodeError as exc:
        logger.error('Unable to read file %s: %s', fname, exc)
        raise ScarlettError(exc)
    # yaml.load returns None for an empty document; normalize to {}.
    return parsed or {}
# def clear_secret_cache() -> None:
# """Clear the secret cache.
#
# Async friendly.
# """
# __SECRET_CACHE.clear()
async def select(self, query: str, values: Union[List, Dict],
                 db_name: str = 'default') -> List[DictRow]:
    """Run a read query and return all matching rows.

    Fix: the body awaits a coroutine, so the function must be declared
    ``async def`` — a plain ``def`` containing ``await`` is a SyntaxError.
    """
    return await self._select(query=query, values=values, db_name=db_name)
async def first(self, query: str, values: Union[List, Dict],
                db_name: str = 'default') -> Optional[DictRow]:
    """Run a read query and return the first row, or None.

    Fix: the body awaits a coroutine, so the function must be declared
    ``async def`` — a plain ``def`` containing ``await`` is a SyntaxError.
    """
    return await self._first(query=query, values=values, db_name=db_name)
async def insert(self, query: str, values: Union[List, Dict],
                 db_name: str = 'default', returning: bool = False):
    """Run an INSERT; with ``returning=True`` the RETURNING rows are fetched.

    Fix: the body awaits a coroutine, so the function must be declared
    ``async def`` — a plain ``def`` containing ``await`` is a SyntaxError.
    """
    return await self._execute(query=query, values=values, db_name=db_name, returning=returning)
async def update(self, query: str, values: Union[List, Dict],
                 db_name: str = 'default', returning: bool = False):
    """Run an UPDATE; with ``returning=True`` the RETURNING rows are fetched.

    Fix: the body awaits a coroutine, so the function must be declared
    ``async def`` — a plain ``def`` containing ``await`` is a SyntaxError.
    """
    return await self._execute(query=query, values=values, db_name=db_name, returning=returning)
async def delete(self, query: str, values: Union[List, Dict], db_name: str = 'default'):
    """Run a DELETE query.

    Fix: the body awaits a coroutine, so the function must be declared
    ``async def`` — a plain ``def`` containing ``await`` is a SyntaxError.
    """
    return await self._execute(query=query, values=values, db_name=db_name)
async def _select(self, query: str, values: Union[List, Dict], db_name: str = 'default'):
    """Execute a read query on the slave pool (falling back to master) and fetch all rows.

    Fix: the body uses ``async with``/``await``, so the function must be
    declared ``async def`` — a plain ``def`` here is a SyntaxError.

    Raises:
        RuntimeError: when neither a slave nor a master pool is configured
            for ``db_name``.
    """
    dbs = self.dbs[db_name]
    # Prefer the read replica; fall back to the master pool.
    pool = dbs.get('slave') or dbs.get('master')
    if pool is None:
        raise RuntimeError('db {} master is not initialized'.format(db_name))
    async with pool.acquire() as conn:
        async with conn.cursor(cursor_factory=DictCursor) as cursor:
            await cursor.execute(query, values)
            return await cursor.fetchall()
async def _first(self, query: str, values: Union[List, Dict], db_name: str = 'default'):
    """Execute a read query on the slave pool (falling back to master) and fetch one row.

    Fix: the body uses ``async with``/``await``, so the function must be
    declared ``async def`` — a plain ``def`` here is a SyntaxError.

    Raises:
        RuntimeError: when neither a slave nor a master pool is configured
            for ``db_name``.
    """
    dbs = self.dbs[db_name]
    # Prefer the read replica; fall back to the master pool.
    pool = dbs.get('slave') or dbs.get('master')
    if pool is None:
        raise RuntimeError('db {} master is not initialized'.format(db_name))
    async with pool.acquire() as conn:
        async with conn.cursor(cursor_factory=DictCursor) as cursor:
            await cursor.execute(query, values)
            return await cursor.fetchone()
def _calc_log_prob_scores(self) -> List[Union[None, float]]:
    """Get log likelihood scores by calling the external ``rnnlm`` binary.

    Writes one candidate token sequence (from ``self.tss``) per line to a
    temporary file, runs ``rnnlm -test`` on it, and parses the scores from
    stdout. Out-of-vocabulary markers become ``None`` entries.

    Returns:
        List of per-line float scores, with ``None`` where the model
        reported an out-of-vocabulary token.
    """
    # NamedTemporaryFile is passed to rnnlm by name; NOTE(review): on
    # Windows a file open in this process cannot be reopened by another —
    # presumably this runs on POSIX only; confirm.
    textfile = tempfile.NamedTemporaryFile(delete=True)
    content = '\n'.join([''.join(ts) for ts in self.tss]) + '\n'
    textfile.write(str.encode(content))
    textfile.seek(0)
    command = ['rnnlm',
               '-rnnlm',
               self.rnnlm_model_path,
               '-test',
               textfile.name]
    process = Popen(command, stdout=PIPE, stderr=PIPE)
    output, err = process.communicate()
    # Keep only non-empty stdout lines; each should be a score or an OOV marker.
    lines = [line.strip() for line in output.decode('UTF-8').split('\n')
             if line.strip() != '']
    scores = []
    for line in lines:
        if line == const.OUT_OF_VOCABULARY:
            scores.append(None)
        else:
            try:
                score = float(line)
                scores.append(score)
            except ValueError:
                # NOTE(review): non-numeric, non-OOV lines are silently
                # dropped, which can misalign `scores` with `self.tss` —
                # confirm this is intended.
                pass
    # Closing the NamedTemporaryFile also deletes it (delete=True).
    textfile.close()
    return scores