python类viewkeys()的实例源码

metadata.py 文件源码 项目:deb-python-cassandra-driver 作者: openstack 项目源码 文件源码 阅读 33 收藏 0 点赞 0 评论 0
def _all_as_cql(self):
        ret = self.as_cql_query(formatted=True)
        ret += ";"

        for index in self.indexes.values():
            ret += "\n%s;" % index.as_cql_query()

        for trigger_meta in self.triggers.values():
            ret += "\n%s;" % (trigger_meta.as_cql_query(),)

        for view_meta in self.views.values():
            ret += "\n\n%s;" % (view_meta.as_cql_query(formatted=True),)

        if self.extensions:
            registry = _RegisteredExtensionType._extension_registry
            for k in six.viewkeys(registry) & self.extensions:  # no viewkeys on OrderedMapSerializeKey
                ext = registry[k]
                cql = ext.after_table_cql(self, k, self.extensions[k])
                if cql:
                    ret += "\n\n%s" % (cql,)

        return ret
formatter.py 文件源码 项目:pytablereader 作者: thombashi 项目源码 文件源码 阅读 34 收藏 0 点赞 0 评论 0
def to_table_data(self):
        """
        Yield a single TableData built from the buffered JSON records.

        ``self._buffer`` is iterated as a sequence of mappings (records). The
        header is the sorted union of the keys of every record, and the
        records themselves are passed through unchanged.

        :raises ValueError:
        :raises pytablereader.error.ValidationError:
        """

        self._validate_source_data()

        header_names = set()
        for record in self._buffer:
            header_names.update(six.viewkeys(record))

        self._loader.inc_table_count()

        yield TableData(
            table_name=self._make_table_name(),
            header_list=sorted(header_names),
            record_list=self._buffer,
            quoting_flags=self._loader.quoting_flags)
formatter.py 文件源码 项目:pytablereader 作者: thombashi 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def to_table_data(self):
        """
        Yield a single TableData built from column-oriented data.

        ``self._buffer`` maps each header to an iterable of column values.
        The headers are sorted and the columns, taken in that order, are
        zipped together into rows.

        :raises ValueError:
        :raises pytablereader.error.ValidationError:
        """

        self._validate_source_data()
        self._loader.inc_table_count()

        headers = sorted(six.viewkeys(self._buffer))
        table_name = self._make_table_name()
        columns = [self._buffer.get(name) for name in headers]

        # NOTE: on Python 3 ``zip`` returns a one-shot iterator.
        yield TableData(
            table_name=table_name,
            header_list=headers,
            record_list=zip(*columns),
            quoting_flags=self._loader.quoting_flags)
formatter.py 文件源码 项目:pytablereader 作者: thombashi 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def to_table_data(self):
        """
        Yield one TableData per entry of ``self._buffer``.

        Each entry maps a table key to a sequence of JSON records (mappings).
        The header of that table is the sorted union of its records' keys,
        and the records are passed through unchanged.

        :raises ValueError:
        :raises pytablereader.error.ValidationError:
        """

        self._validate_source_data()

        for key, records in six.iteritems(self._buffer):
            header_names = set()
            for record in records:
                header_names.update(six.viewkeys(record))

            self._loader.inc_table_count()
            # Set before _make_table_name() runs (presumably it reads it).
            self._table_key = key

            yield TableData(
                table_name=self._make_table_name(),
                header_list=sorted(header_names),
                record_list=records,
                quoting_flags=self._loader.quoting_flags)
formatter.py 文件源码 项目:pytablereader 作者: thombashi 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def to_table_data(self):
        """
        Yield one TableData per entry of ``self._buffer``.

        Each entry maps a table key to a column-oriented mapping (header ->
        iterable of values). The headers are sorted and the columns, in that
        order, are zipped together into rows.

        :raises ValueError:
        :raises pytablereader.error.ValidationError:
        """

        self._validate_source_data()

        for key, columns_by_header in six.iteritems(self._buffer):
            headers = sorted(six.viewkeys(columns_by_header))

            self._loader.inc_table_count()
            # Set before _make_table_name() runs (presumably it reads it).
            self._table_key = key

            table_name = self._make_table_name()
            columns = [columns_by_header.get(name) for name in headers]
            # NOTE: on Python 3 ``zip`` returns a one-shot iterator.
            yield TableData(
                table_name=table_name,
                header_list=headers,
                record_list=zip(*columns),
                quoting_flags=self._loader.quoting_flags)
stat_density.py 文件源码 项目:plotnine 作者: has2k1 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def setup_params(self, data):
        """
        Normalise the ``kernel`` parameter to its abbreviation.

        A full kernel name is accepted case-insensitively and replaced by its
        abbreviation. Any other value (e.g. an abbreviation) is kept as given.

        :raises PlotnineError: if the resulting kernel is not one of the
            known abbreviations.
        """
        params = self.params.copy()
        abbreviations = {
            'biweight': 'biw',
            'cosine': 'cos',
            'cosine2': 'cos2',
            'epanechnikov': 'epa',
            'gaussian': 'gau',
            'triangular': 'tri',
            'triweight': 'triw',
            'uniform': 'uni'}

        # Values that are not full kernel names are left untouched here.
        with suppress(KeyError):
            params['kernel'] = abbreviations[params['kernel'].lower()]

        accepted = six.viewvalues(abbreviations)
        if params['kernel'] in accepted:
            return params

        msg = ("kernel should be one of {}. "
               "You may use the abbreviations {}")
        raise PlotnineError(msg.format(six.viewkeys(abbreviations), accepted))
geom.py 文件源码 项目:plotnine 作者: has2k1 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def _verify_arguments(self, kwargs):
        """
        Verify arguments passed to the geom.

        Every keyword must be one of the following:

        - a geom aesthetic;
        - a geom parameter;
        - a stat aesthetic;
        - a stat parameter;
        - one of the layer parameters ``data``, ``mapping``, ``show_legend``
          or ``inherit_aes``.

        :raises PlotnineError: naming the keywords that match none of these.
        """
        view = six.viewkeys
        accepted_groups = (
            self.aesthetics(),                 # geom aesthetics
            view(self.DEFAULT_PARAMS),         # geom parameters
            self._stat.aesthetics(),           # stat aesthetics
            view(self._stat.DEFAULT_PARAMS),   # stat parameters
            {'data', 'mapping',                # layer parameters
             'show_legend', 'inherit_aes'},    # layer parameters
        )

        unknown = view(kwargs)
        for group in accepted_groups:
            unknown = unknown - group

        if not unknown:
            return
        msg = ("Parameters {}, are not understood by "
               "either the geom, stat or layer.")
        raise PlotnineError(msg.format(unknown))
predicates.py 文件源码 项目:catalyst 作者: enigmampc 项目源码 文件源码 阅读 59 收藏 0 点赞 0 评论 0
def assert_dict_equal(result, expected, path=(), msg='', **kwargs):
    """Assert that two dicts have the same keys and equal values.

    The key sets are compared first with ``_check_sets``. Values are then
    compared key by key with ``assert_equal``. Every mismatch is collected,
    so a single AssertionError reports all of them.

    Parameters
    ----------
    result, expected : dict
        The dicts to compare.
    path : tuple, optional
        Path of the values being compared. It is extended with ``[key]``
        for each value, so failures point at the offending entry.
    msg : str, optional
        Extra message passed through to the underlying checks.
    **kwargs
        Forwarded to ``assert_equal``.
    """
    keys_call = '.%s()' % ('viewkeys' if PY2 else 'keys')
    _check_sets(viewkeys(result), viewkeys(expected), msg,
                path + (keys_call,), 'key')

    failures = []
    for key, (got, want) in iteritems(dzip_exact(result, expected)):
        try:
            assert_equal(got, want, path=path + ('[%r]' % (key,),),
                         msg=msg, **kwargs)
        except AssertionError as exc:
            failures.append(str(exc))

    if failures:
        raise AssertionError('\n'.join(failures))
test_assets.py 文件源码 项目:catalyst 作者: enigmampc 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def test_blocked_lookup_symbol_query(self):
        """Retrieving more equities than SQLite can bind in one query works.

        The asset finder must split the lookup into chunks on the client
        side, and every requested sid must come back.
        """
        as_of = pd.Timestamp('2013-01-01', tz='UTC')
        # More sids than SQLite allows bound variables in a single query.
        sids = range(SQLITE_MAX_VARIABLE_NUMBER + 10)

        def make_row(sid):
            return {
                'sid': sid,
                'symbol': 'TEST.%d' % sid,
                'start_date': as_of.value,
                'end_date': as_of.value,
                'exchange': uuid.uuid4().hex,
            }

        frame = pd.DataFrame.from_records([make_row(sid) for sid in sids])
        self.write_assets(equities=frame)

        assets = self.asset_finder.retrieve_equities(sids)
        assert_equal(viewkeys(assets), set(sids))
linthompsamp.py 文件源码 项目:striatum 作者: ntucllab 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def _linthompsamp_score(self, context):
        """Thompson Sampling.

        Draws ``mu_tilde ~ N(mu_hat, v**2 * inv(B))`` from the stored model
        and scores every action's context vector.

        :param context: mapping of action id -> context (feature) vector.
        :return: three dicts keyed by action id, in this order:

            - estimated reward (``context . mu_hat``);
            - uncertainty (sampled score minus estimated reward);
            - sampled score (``context . mu_tilde``).
        """
        action_ids = list(six.viewkeys(context))
        features = np.asarray([context[a] for a in action_ids])

        model = self._model_storage.get_model()
        B = model['B']  # pylint: disable=invalid-name
        mu_hat = model['mu_hat']

        v = self.R * np.sqrt(24 / self.epsilon
                             * self.context_dimension
                             * np.log(1 / self.delta))
        mu_tilde = self.random_state.multivariate_normal(
            mu_hat.flat, v**2 * np.linalg.inv(B))[..., np.newaxis]

        estimated = features.dot(mu_hat)
        sampled = features.dot(mu_tilde)
        rows = list(zip(action_ids, estimated, sampled))

        estimated_reward_dict = {a: float(e) for a, e, _ in rows}
        uncertainty_dict = {a: float(s - e) for a, e, s in rows}
        score_dict = {a: float(s) for a, _, s in rows}
        return estimated_reward_dict, uncertainty_dict, score_dict
localdisk_service.py 文件源码 项目:treadmill 作者: Morgan-Stanley 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def synchronize(self):
        """Make sure that all stale volumes are removed.

        The ``stale`` flag is popped from every volume record. Volumes where
        that flag was truthy are destroyed. If at least one volume was
        destroyed, three things follow:

        - every pending request is retried;
        - the pending list is cleared;
        - the cached volume-group status is refreshed.
        """
        destroyed_any = False
        # Iterate over a snapshot of the keys in case _destroy_volume
        # changes self._volumes.
        for volume_id in six.viewkeys(self._volumes.copy()):
            if self._volumes[volume_id].pop('stale', False):
                destroyed_any = True
                self._destroy_volume(volume_id)

        if destroyed_any:
            # Now that a volume was removed, retry all the pending resources.
            # NOTE(review): if _retry_request re-appends to self._pending,
            # this loop keeps walking the growing list and the re-queued ids
            # are then dropped by the reset below -- confirm.
            for pending_id in self._pending:
                self._retry_request(pending_id)
            self._pending = []

            # Refresh the cached status from LVM after destroying a volume.
            self._vg_status = localdiskutils.refresh_vg_status(
                localdiskutils.TREADMILL_VG
            )
resources.py 文件源码 项目:txdarn 作者: markrwilliams 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def __init__(self, policies=None):
        """
        Validate and bind this resource's policies.

        :param policies: mapping of method name -> sequence of policies.
            Defaults to the ``policies`` attribute when omitted.
        :raises ValueError: in any of these cases:

            - ``policies`` is not a mapping;
            - the instance has no (or an empty) ``allowedMethods``;
            - an allowed method has no entry in ``policies``.
        """
        policies = self.policies if policies is None else policies

        if not isinstance(policies, compat.Mapping):
            raise ValueError("policies must be a mapping of bytes"
                             " method names to sequence of policies.")

        allowed = getattr(self, 'allowedMethods', None)
        if not allowed:
            raise ValueError("instance must have allowedMethods")

        missing = set(allowed) - six.viewkeys(policies)
        if missing:
            raise ValueError("missing methods: {}".format(missing))

        # Adapt every policy to this resource, keeping them grouped by method.
        acting = {}
        for method, method_policies in policies.items():
            acting[method] = tuple(policy.forResource(self)
                                   for policy in method_policies)
        self._actingPolicies = acting
b301_b302_b305.py 文件源码 项目:flake8-bugbear 作者: PyCQA 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def this_is_okay():
    """Calls that should presumably NOT be reported by checks B301/B302/B305.

    NOTE(review): this appears to be a flake8-bugbear test fixture (its file
    is ``b301_b302_b305.py``). Names such as ``iterkeys``, ``six``,
    ``future`` and ``builtins`` are not defined in this snippet, so the body
    is meant to be parsed by the linter, not executed. Line positions may be
    asserted by the test suite; confirm before editing.

    Every call passes the dict as an argument rather than calling an
    ``.iter*`` / ``.view*`` / ``.next`` method on it.
    """
    d = {}
    # Bare function-style helper.
    iterkeys(d)
    # six helpers.
    six.iterkeys(d)
    six.itervalues(d)
    six.iteritems(d)
    six.iterlists(d)
    six.viewkeys(d)
    six.viewvalues(d)
    six.viewlists(d)
    itervalues(d)
    # python-future helpers.
    future.utils.iterkeys(d)
    future.utils.itervalues(d)
    future.utils.iteritems(d)
    future.utils.iterlists(d)
    future.utils.viewkeys(d)
    future.utils.viewvalues(d)
    future.utils.viewlists(d)
    # next() as a function, not as a method on d.
    six.next(d)
    builtins.next(d)
metadata.py 文件源码 项目:python-dse-driver 作者: datastax 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def _all_as_cql(self):
        ret = self.as_cql_query(formatted=True)
        ret += ";"

        for index in self.indexes.values():
            ret += "\n%s;" % index.as_cql_query()

        for trigger_meta in self.triggers.values():
            ret += "\n%s;" % (trigger_meta.as_cql_query(),)

        for view_meta in self.views.values():
            ret += "\n\n%s;" % (view_meta.as_cql_query(formatted=True),)

        if self.extensions:
            registry = _RegisteredExtensionType._extension_registry
            for k in six.viewkeys(registry) & self.extensions:  # no viewkeys on OrderedMapSerializeKey
                ext = registry[k]
                cql = ext.after_table_cql(self, k, self.extensions[k])
                if cql:
                    ret += "\n\n%s" % (cql,)

        return ret
functional.py 文件源码 项目:zipline-chinese 作者: zhanghan1990 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def dzip_exact(*dicts):
    """
    Zip several dicts that share exactly the same keys.

    Parameters
    ----------
    *dicts : iterable[dict]
        A sequence of dicts all sharing the same keys.

    Returns
    -------
    zipped : dict
        Maps every shared key to a tuple of length ``len(dicts)``. The tuple
        holds the value each dict has for that key, in argument order.

    Raises
    ------
    ValueError
        If dicts don't all have the same keys.

    Examples
    --------
    >>> result = dzip_exact({'a': 1, 'b': 2}, {'a': 3, 'b': 4})
    >>> result == {'a': (1, 3), 'b': (2, 4)}
    True
    """
    if same(*map(viewkeys, dicts)):
        return {
            key: tuple(mapping[key] for mapping in dicts)
            for key in dicts[0]
        }
    raise ValueError(
        "dict keys not all equal:\n\n%s" % _format_unequal_keys(dicts)
    )
assets.py 文件源码 项目:zipline-chinese 作者: zhanghan1990 项目源码 文件源码 阅读 41 收藏 0 点赞 0 评论 0
def _convert_asset_timestamp_fields(dict_):
    """
    Convert the timestamp fields of a dict of Asset init args in place.

    Every key that is present in ``dict_`` and also listed in
    ``_asset_timestamp_fields`` is converted to a UTC ``pd.Timestamp``.
    Values that convert to a null timestamp become ``None``.

    Returns the same (mutated) dict.
    """
    present = _asset_timestamp_fields & viewkeys(dict_)
    for field in present:
        stamp = pd.Timestamp(dict_[field], tz='UTC')
        if isnull(stamp):
            stamp = None
        dict_[field] = stamp
    return dict_
metadata.py 文件源码 项目:deb-python-cassandra-driver 作者: openstack 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def as_cql_query(self, formatted=False):
        """
        Return a CQL query that can be used to recreate this materialized view.

        If `formatted` is set to :const:`True`, extra whitespace will be
        added to make the query more readable. Any CQL contributed by
        registered extensions (via ``after_table_cql``) is appended.
        """
        sep = '\n    ' if formatted else ' '
        keyspace = protect_name(self.keyspace_name)
        name = protect_name(self.name)

        if self.include_all_columns:
            selected_cols = '*'
        else:
            selected_cols = ', '.join(
                protect_name(col.name) for col in self.columns.values())
        base_table = protect_name(self.base_table_name)
        where_clause = self.where_clause

        partition_cols = ', '.join(
            protect_name(col.name) for col in self.partition_key)
        # A composite partition key is wrapped in its own parentheses.
        if len(self.partition_key) > 1:
            partition_cols = "(%s)" % partition_cols
        pk_parts = [partition_cols]
        if self.clustering_key:
            pk_parts.append(', '.join(
                protect_name(col.name) for col in self.clustering_key))
        pk = "(%s)" % ", ".join(pk_parts)

        properties = TableMetadataV3._property_string(
            formatted, self.clustering_key, self.options)

        ret = ("CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s"
               "SELECT %(selected_cols)s%(sep)s"
               "FROM %(keyspace)s.%(base_table)s%(sep)s"
               "WHERE %(where_clause)s%(sep)s"
               "PRIMARY KEY %(pk)s%(sep)s"
               "WITH %(properties)s") % {
                   'keyspace': keyspace,
                   'name': name,
                   'sep': sep,
                   'selected_cols': selected_cols,
                   'base_table': base_table,
                   'where_clause': where_clause,
                   'pk': pk,
                   'properties': properties,
               }

        if self.extensions:
            registry = _RegisteredExtensionType._extension_registry
            # OrderedMapSerializeKey has no viewkeys, so start from the registry.
            for ext_name in six.viewkeys(registry) & self.extensions:
                cql = registry[ext_name].after_table_cql(
                    self, ext_name, self.extensions[ext_name])
                if cql:
                    ret += "\n\n%s" % (cql,)
        return ret
stat.py 文件源码 项目:plotnine 作者: has2k1 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def __init__(self, *args, **kwargs):
        """
        Record the stat's keyword arguments, parameters and aesthetic params.

        ``params`` is built by ``copy_keys`` from the keyword arguments and a
        deep copy of ``DEFAULT_PARAMS``. ``aes_params`` holds the keyword
        arguments whose names are aesthetics of this stat.
        """
        kwargs = data_mapping_as_kwargs(args, kwargs)
        # Kept as-is; it will be used to create the geom.
        self._kwargs = kwargs
        self.params = copy_keys(kwargs, deepcopy(self.DEFAULT_PARAMS))
        given_aes = self.aesthetics() & six.viewkeys(kwargs)
        self.aes_params = dict((ae, kwargs[ae]) for ae in given_aes)
stat.py 文件源码 项目:plotnine 作者: has2k1 项目源码 文件源码 阅读 39 收藏 0 点赞 0 评论 0
def use_defaults(self, data):
        """
        Combine data with defaults and set aesthetics from parameters

        stats should not override this method.

        Aesthetics that are neither set as parameters nor present as columns
        of ``data`` receive their ``DEFAULT_AES`` value. Exceptions are
        required aesthetics and aesthetics whose default is ``None``. Every
        aesthetic given in ``aes_params`` is then assigned to ``data``,
        overriding any existing column.

        Parameters
        ----------
        data : pandas.DataFrame
            Data used for drawing the geom.

        Returns
        -------
        out : pandas.DataFrame
            Data used for drawing the geom (the same object, modified in
            place).
        """
        # Aesthetics not set as parameters and not already in the data.
        missing = (self.aesthetics() -
                   set(self.aes_params) -
                   set(data.columns))

        # Required aesthetics get no default here, and a default of None
        # means "leave the aesthetic unset".
        for ae in missing - self.REQUIRED_AES:
            if self.DEFAULT_AES[ae] is not None:
                data[ae] = self.DEFAULT_AES[ae]

        # Aesthetics set as parameters always win.
        for ae, value in self.aes_params.items():
            data[ae] = value

        return data
stat_ydensity.py 文件源码 项目:plotnine 作者: has2k1 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def setup_params(self, data):
        """
        Validate ``scale`` and ``kernel`` and fill in density defaults.

        ``kernel`` may be given by full name (case-insensitive) or by its
        abbreviation, and is normalised to the abbreviation. Any parameter of
        ``stat_density.DEFAULT_PARAMS`` missing from ``self.params`` is added
        with its default value.

        :raises PlotnineError: if ``scale`` or ``kernel`` is not recognised.
        """
        params = self.params.copy()

        valid_scale = ('area', 'count', 'width')
        if params['scale'] not in valid_scale:
            msg = "Parameter scale should be one of {}"
            raise PlotnineError(msg.format(valid_scale))

        lookup = {
            'biweight': 'biw',
            'cosine': 'cos',
            'cosine2': 'cos2',
            'epanechnikov': 'epa',
            'gaussian': 'gau',
            'triangular': 'tri',
            'triweight': 'triw',
            'uniform': 'uni'}

        # Full names become abbreviations; anything else is kept as given.
        with suppress(KeyError):
            params['kernel'] = lookup[params['kernel'].lower()]

        if params['kernel'] not in six.viewvalues(lookup):
            msg = ("kernel should be one of {}. "
                   "You may use the abbreviations {}")
            # six.viewvalues needs the mapping. Calling it without one raised
            # a TypeError here and hid the intended PlotnineError.
            raise PlotnineError(msg.format(six.viewkeys(lookup),
                                           six.viewvalues(lookup)))

        # Fill in any stat_density parameter that was not provided.
        missing_params = (six.viewkeys(stat_density.DEFAULT_PARAMS) -
                          six.viewkeys(params))
        for key in missing_params:
            params[key] = stat_density.DEFAULT_PARAMS[key]

        return params
geom.py 文件源码 项目:plotnine 作者: has2k1 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def aesthetics(cls):
        """
        Return all the aesthetics for this geom.

        The result contains:

        - the keys of ``DEFAULT_AES``;
        - the ``REQUIRED_AES``;
        - ``group``;
        - ``colour`` / ``outlier_colour`` whenever ``color`` /
          ``outlier_color`` is present.

        geoms should not override this method.
        """
        found = six.viewkeys(cls.DEFAULT_AES) | cls.REQUIRED_AES
        extras = {'group'}
        # Need to recognize both spellings.
        for us_name, uk_name in (('color', 'colour'),
                                 ('outlier_color', 'outlier_colour')):
            if us_name in found:
                extras.add(uk_name)
        return found | extras
geom.py 文件源码 项目:plotnine 作者: has2k1 项目源码 文件源码 阅读 38 收藏 0 点赞 0 评论 0
def use_defaults(self, data):
        """
        Combine data with defaults and set aesthetics from parameters.

        geoms should not override this method.

        Parameters
        ----------
        data : pandas.DataFrame
            Data used for drawing the geom.

        Returns
        -------
        out : pandas.DataFrame
            Data used for drawing the geom.

        Raises
        ------
        PlotnineError
            If a parameter value cannot be assigned as a column and does not
            look like a valid value for its aesthetic.
        """
        set_as_params = six.viewkeys(self.aes_params)
        defaults_needed = (six.viewkeys(self.DEFAULT_AES) -
                           set_as_params -
                           set(data.columns))

        # Not in data and not set, use default
        for ae in defaults_needed:
            data[ae] = self.DEFAULT_AES[ae]

        # If set, use it
        for ae, value in self.aes_params.items():
            try:
                data[ae] = value
            except ValueError:
                # Plain assignment fails for special values such as custom
                # tupled linetypes, shapes and colors; repeat them per row.
                if not is_valid_aesthetic(value, ae):
                    msg = ("'{}' does not look like a "
                           "valid value for `{}`")
                    raise PlotnineError(msg.format(value, ae))
                data[ae] = [value]*len(data)

        return data
data_generator.py 文件源码 项目:feagen 作者: ianlini 项目源码 文件源码 阅读 36 收藏 0 点赞 0 评论 0
def __init__(self, handlers):
        """
        Store *handlers* after checking they match the expected handler set.

        :param handlers: mapping whose keys must equal ``self._handler_set``.
        :raises ValueError: listing the unexpected ("redundant") and the
            missing ("lacked") handler names when the key sets differ.
        """
        provided = set(six.viewkeys(handlers))
        expected = self._handler_set
        if provided == expected:
            self._handlers = handlers
            return
        raise ValueError('Handler set mismatch. {} redundant and {} lacked.'
                         .format(provided - expected, expected - provided))
driver.py 文件源码 项目:zun 作者: openstack 项目源码 文件源码 阅读 41 收藏 0 点赞 0 评论 0
def update_containers_states(self, context, containers):
        """
        Sync status and host of *containers* with ``self.list(context)``.

        Only containers whose ``container_id`` appears both in *containers*
        and in ``self.list(context)`` are touched. For each of them:

        - the status is copied from the listed container;
        - the host is set to ``CONF.host``.

        Each change is saved and logged. Nothing happens when
        ``self.list(context)`` is empty.
        """
        listed = self.list(context)
        if not listed:
            return

        listed_by_id = {}
        for item in listed:
            listed_by_id[item.container_id] = item
        given_by_id = {}
        for item in containers:
            given_by_id[item.container_id] = item

        common_ids = six.viewkeys(given_by_id) & six.viewkeys(listed_by_id)
        for cid in common_ids:
            container = given_by_id[cid]
            source = listed_by_id[cid]

            # sync status
            if container.status != source.status:
                old_status = container.status
                container.status = source.status
                container.save(context)
                LOG.info('Status of container %s changed from %s to %s',
                         container.uuid, old_status, container.status)

            # sync host with the configured current host
            cur_host = CONF.host
            if container.host != cur_host:
                old_host = container.host
                container.host = cur_host
                container.save(context)
                LOG.info('Host of container %s changed from %s to %s',
                         container.uuid, old_host, container.host)
config.py 文件源码 项目:sd2 作者: gae123 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def process_inheritance(config_dct, keys):
    """Resolve ``extends`` inheritance for each top-level section in *keys*.

    For every name in *keys*, ``config_dct[name]`` is a list of host dicts.
    Each has a ``name`` and optionally ``extends`` (one name or a list of
    names) and ``abstract``. Hosts are visited in the order returned by
    ``_dfs``. Each host is built by merging the hosts it extends (in order)
    and then the host itself:

    - list values are merged without duplicates;
    - any other value is deep-copied, and the later one wins.

    ``abstract`` and ``extends`` are removed from every resolved host. Each
    resolved host replaces its entry in the lookup, so hosts visited later
    inherit from already-resolved ones. Abstract hosts are dropped, and
    ``config_dct[name]`` is replaced by the list of resolved hosts.
    """
    def get_processed_dct(host, hostsdict):
        # Merge the hosts named in ``extends`` and then the host itself.
        merged = {}
        # A missing or empty ``extends`` means no parents; one name is
        # accepted as well as a list/tuple of names.
        extends = host.get('extends') or []
        if isinstance(extends, six.string_types):
            extends = [extends]
        for extend in list(extends) + [host['name']]:
            extend_host = hostsdict[extend]
            for key, value in extend_host.items():
                if key in merged and isinstance(merged[key], list):
                    values = (value if isinstance(value, (list, tuple))
                              else [value])
                    for val in values:
                        if val not in merged[key]:
                            merged[key].append(val)
                else:
                    merged[key] = copy.deepcopy(value)
        return merged

    for tlkey in keys:
        section = config_dct.get(tlkey, [])
        hostsdict = {x['name']: x for x in section}
        resolved = []
        for dct in _dfs(section):
            isabstract = dct.get('abstract')
            dct = get_processed_dct(dct, hostsdict)
            # These keys only drive inheritance; strip them from the result.
            for key in ('abstract', 'extends'):
                dct.pop(key, None)
            hostsdict[dct['name']] = dct
            if not isabstract:
                resolved.append(dct)

        config_dct[tlkey] = resolved
tests.py 文件源码 项目:django-requestlogging 作者: tarkatronic 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def test_find_loggers_with_filter(self):
        """Only our logger is found, carrying exactly one RequestFilter."""
        found = self.middleware.find_loggers_with_filter(RequestFilter)
        self.assertListEqual(list(six.viewkeys(found)), [self.logger])

        filter_types = [type(f) for f in found[self.logger]]
        self.assertEqual(filter_types, [RequestFilter], found[self.logger])
cache.py 文件源码 项目:odoo-rpc-client 作者: katyukha 项目源码 文件源码 阅读 30 收藏 0 点赞 0 评论 0
def update_keys(self, keys):
        """ Add new IDs to cache.

            Each new ID is stored as ``{'id': <ID>}``. IDs already in the
            cache keep their current entry.

            :param list keys: list of new IDs to be added to cache
            :return: self
            :rtype: ObjectCache
        """
        if self:
            # Skip IDs that are already cached.
            keys = set(keys).difference(six.viewkeys(self))
        # For an empty cache no set/difference work is needed, which may be
        # faster for large amounts of data.
        self.update(dict((cid, {'id': cid}) for cid in keys))
        return self
functional.py 文件源码 项目:catalyst 作者: enigmampc 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def dzip_exact(*dicts):
    """
    Zip several dicts that share exactly the same keys.

    Parameters
    ----------
    *dicts : iterable[dict]
        A sequence of dicts all sharing the same keys.

    Returns
    -------
    zipped : dict
        Maps every shared key to a tuple of length ``len(dicts)``. The tuple
        holds the value each dict has for that key, in argument order.

    Raises
    ------
    ValueError
        If dicts don't all have the same keys.

    Examples
    --------
    >>> result = dzip_exact({'a': 1, 'b': 2}, {'a': 3, 'b': 4})
    >>> result == {'a': (1, 3), 'b': (2, 4)}
    True
    """
    if same(*map(viewkeys, dicts)):
        return {
            key: tuple(mapping[key] for mapping in dicts)
            for key in dicts[0]
        }
    raise ValueError(
        "dict keys not all equal:\n\n%s" % _format_unequal_keys(dicts)
    )
assets.py 文件源码 项目:catalyst 作者: enigmampc 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def _convert_asset_timestamp_fields(dict_):
    """
    Convert the timestamp fields of a dict of Asset init args in place.

    Every key that is present in ``dict_`` and also listed in
    ``_asset_timestamp_fields`` is converted to a UTC ``pd.Timestamp``.
    Values that convert to a null timestamp become ``None``.

    Returns the same (mutated) dict.
    """
    present = _asset_timestamp_fields & viewkeys(dict_)
    for field in present:
        stamp = pd.Timestamp(dict_[field], tz='UTC')
        if isnull(stamp):
            stamp = None
        dict_[field] = stamp
    return dict_
exp4p.py 文件源码 项目:striatum 作者: ntucllab 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def _exp4p_score(self, context):
        """The main part of Exp4.P.

        :param context: mapping of advisor id -> advice, indexed as
            ``context[advisor_id][action_id]``.
        :return: ``(estimated_reward, uncertainty, score)``, three dicts
            keyed by action id:

            - estimated reward: the action's normalised probability;
            - uncertainty: always 0;
            - score: the same probability as the estimated reward.

        If the stored weights ``w`` are empty, every advisor starts with
        weight 1. The probabilities and the weights are saved back to the
        model storage.
        """
        advisor_ids = list(six.viewkeys(context))

        w = self._modelstorage.get_model()['w']
        if len(w) == 0:
            for advisor_id in advisor_ids:
                w[advisor_id] = 1
        w_sum = sum(six.viewvalues(w))

        # Mix the weight-averaged advice with a p_min floor per action.
        floor = self.p_min
        scale = 1 - self.n_actions * floor
        probs = []
        for action_id in self.action_ids:
            advice = [w[advisor_id] * context[advisor_id][action_id]
                      for advisor_id in advisor_ids]
            probs.append(scale * (np.sum(advice) / w_sum) + floor)
        probs = np.asarray(probs)
        probs /= probs.sum()

        estimated_reward = dict(zip(self.action_ids, probs))
        uncertainty = dict.fromkeys(self.action_ids, 0)
        score = dict(zip(self.action_ids, probs))
        self._modelstorage.save_model(
            {'action_probs': estimated_reward, 'w': w})

        return estimated_reward, uncertainty, score


问题


面经


文章

微信
公众号

扫码关注公众号