Python numpy.argwhere() 的实例源码

distances.py 文件源码 项目:coordinates 作者: markovmodel 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def transform(self, traj):
        """Map a trajectory to per-group minimum distances.

        For every ``(gi, gf)`` pair in ``self.group_identifiers`` the row-wise
        minimum over distance columns ``gi:gf`` is stored in one output column.
        When ``self.threshold`` is set the result is binarised: 1.0 where the
        group minimum is ``<= self.threshold`` and 0.0 elsewhere (NaN minima
        therefore map to 0.0). Otherwise the raw minima are returned.

        Returns an array of shape ``(traj.n_frames, self.dimension)``.
        """
        distances = mdtraj.compute_distances(traj, self.distance_indexes,
                                             periodic=self.periodic)
        minima = np.zeros((traj.n_frames, self.dimension))
        for column, (start, stop) in enumerate(self.group_identifiers):
            minima[:, column] = distances[:, start:stop].min(1)

        if self.threshold is None:
            return minima
        # Binary contact map: 1.0 where the group minimum is within threshold.
        return np.where(minima <= self.threshold, 1.0, 0.0)
lc.py 文件源码 项目:PyCS 作者: COSMOGRAIL 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def maskinfo(self):
        """
        Returns a description of the masked points and available properties of them.
        Note that the output format can be directly used as a skiplist.

        The first line is a ``#`` header with the number of masked points
        (entries where ``self.mask`` is False). Each following line holds the
        ``self.jds`` value of one masked point formatted as ``%.1f``, four
        spaces, then its common properties as ``name : value`` pairs.
        """
        cps = self.commonproperties()
        masked = (self.mask == False)  # elementwise array comparison
        lines = []
        # flatnonzero yields plain integer positions. np.argwhere would yield
        # length-1 index arrays, which cannot index a Python list of property
        # dicts and turn ``self.jds[...]`` into a 1-element array.
        for maskindex in np.flatnonzero(masked):
            comment = ", ".join(["%s : %s" % (cp, self.properties[maskindex][cp]) for cp in cps])
            txt = "%.1f    %s" % (self.jds[maskindex], comment)
            lines.append(txt)

        txt = "\n".join(lines)
        txt = "# %i Masked points of %s :\n" % (np.sum(masked), str(self)) + txt
        return txt
find_null.py 文件源码 项目:impyute 作者: eltonlaw 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def find_null(data):
    """ Finds the indices of all missing (NaN) values.

    Parameters
    ----------
    data: numpy.ndarray
        Data to impute.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n_missing, data.ndim)``. Row ``k`` holds the
        index of the k-th NaN entry in row-major order, e.g. ``[i, j]`` for
        2-D data. This is an ndarray, not a list of tuples; use
        ``tuple(row)`` or ``map(tuple, result)`` if tuples are needed.

    """
    null_xy = np.argwhere(np.isnan(data))
    return null_xy
eval_helpers.py 文件源码 项目:poseval 作者: leonid-pishchulin 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def VOCap(rec,prec):
    """Compute PASCAL-VOC style average precision from a recall/precision curve.

    Recall is padded with 0 in front and 1 at the end, and precision with 0 at
    both ends. Precision is then replaced by its running maximum taken from
    the right (a monotone envelope), and the area is summed over the
    positions where recall changes value.

    Returns a numpy float scalar.
    """
    n_prec = len(prec)
    n_rec = len(rec)

    precision = np.zeros(n_prec + 2)
    precision[1:n_prec + 1] = prec
    recall = np.zeros(n_rec + 2)
    recall[1:n_rec + 1] = rec
    recall[n_rec + 1] = 1.0

    # Right-to-left running max, element by element.
    for k in reversed(range(precision.size - 1)):
        precision[k] = max(precision[k], precision[k + 1])

    # Positions (in the padded arrays) where recall steps to a new value.
    steps = np.flatnonzero(recall[1:] != recall[:-1]) + 1

    return np.sum((recall[steps] - recall[steps - 1]) * precision[steps])
preprocess.py 文件源码 项目:plda 作者: RaviSoji 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def get_principal_components(flattened_images, n_components='default',
                             default_pct_variance_explained=.96):
    """ Standardizes the data and gets the principal components.

    Parameters
    ----------
    flattened_images : sequence of 1-D numpy.ndarray
        One flattened image per entry; all must share the same shape. The
        input is not modified.
    n_components : int or 'default'
        Index of the last principal axis kept: the first
        ``n_components + 1`` axes are used. With ``'default'`` the index is
        the first one at which the cumulative explained-variance fraction
        reaches ``default_pct_variance_explained``. If it never does (e.g.
        float round-off keeps the total just below 1.0), the last axis is used.
    default_pct_variance_explained : float
        Target cumulative explained-variance fraction for ``'default'``.

    Returns
    -------
    numpy.ndarray
        The standardized data projected onto the kept axes, shape
        ``(n_images, n_kept)``.
    """
    for img in flattened_images:
        assert isinstance(img, np.ndarray)
        assert img.shape == flattened_images[-1].shape
        assert len(img.shape) == 1
    # Copy into a float array. np.asarray would alias a float ndarray passed
    # by the caller, so the in-place centering below would overwrite it. It
    # would also keep integer dtypes, for which ``-=`` with a float mean fails.
    X = np.array(flattened_images, dtype=float)
    X -= X.mean(axis=0)  # Center all of the data around the origin.
    std = np.std(X, axis=0)
    # Constant features (e.g. always-black border pixels) have zero spread.
    # Leave them at 0 instead of producing NaN from 0 / 0.
    std[std == 0] = 1.0
    X /= std

    pca = PCA()
    pca.fit(X)

    if n_components == 'default':
        sorted_eig_vals = pca.explained_variance_
        cum_pct_variance = (sorted_eig_vals / sorted_eig_vals.sum()).cumsum()
        idxs = np.flatnonzero(cum_pct_variance >= default_pct_variance_explained)
        # Works for zero, one, or many qualifying indices.
        n_components = int(idxs[0]) if idxs.size else cum_pct_variance.size - 1

    # NOTE(review): an explicit integer n_components keeps n_components + 1
    # axes, which looks off by one for a parameter with this name. It is left
    # unchanged to preserve the output shape existing callers see; confirm intent.
    V = pca.components_[:n_components + 1, :].T
    principal_components = np.matmul(X, V)

    return principal_components
FQI.py 文件源码 项目:TreasureBot 作者: SamuelePolimi 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def fit(self,s,a,r,s_next):
        """Refit one regressor per action on a batch of transitions.

        ``self.Q[k]`` is refitted on the rows of ``s`` whose action in ``a``
        equals ``k``. On the first call the regression target is the reward
        ``r``. On later calls it is ``r + max_k self.Q[k].predict(s_next)``.

        :param s: state features, one sample per row.
        :param a: action taken per sample; a 1-D array or an (n, 1) column.
        :param r: reward per sample; 1-D or (n, 1).
        :param s_next: next-state features, passed to every ``predict``.
        """
        # flatnonzero gives 1-D row positions for both 1-D and (n, 1) action
        # arrays. np.argwhere on an (n, 1) array also returns the (all-zero)
        # column index, which corrupts the selected rows and targets.
        actions = np.asarray(a)
        if not self.first_time:
            # q_next[i, k] = predicted value of action k at s_next[i].
            q_next = np.column_stack(
                [self.Q[a_i].predict(s_next).reshape(-1)
                 for a_i in range(self.n_action)])
            q_max = np.max(q_next, axis=1)
            for a_i in range(self.n_action):
                indx = np.flatnonzero(actions == a_i)
                # NOTE(review): the target is undiscounted (self.gamma is not
                # applied); confirm this is intended.
                y = r[indx].ravel() + q_max[indx]
                self.Q[a_i].fit(s[indx, :], y)
        else:
            for a_i in range(self.n_action):
                indx = np.flatnonzero(actions == a_i)
                self.Q[a_i].fit(s[indx, :], r[indx].ravel())
            self.first_time = False
space_invaders.py 文件源码 项目:human-rl 作者: gsastry 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def ship_location(image):
    """Return the column of the ship on scan line 185, or None.

    A column matches when its pixel in row 185 equals ``SHIP_COLOR`` in every
    channel. The column index is returned only when exactly one column
    matches; zero or several matches give None.
    """
    # assumes image is indexed (row, column, channel) — TODO confirm
    scanline = image[185, :, :]
    mismatch = np.abs(scanline - SHIP_COLOR).sum(axis=1)
    hits = np.argwhere(mismatch == 0)
    if len(hits) != 1:
        return None
    return hits[0][0]
space_invaders.py 文件源码 项目:human-rl 作者: gsastry 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def ship_location(image):
    """Return the column of the ship on scan line 185, or None.

    A column matches when its pixel in row 185 equals ``SHIP_COLOR`` in every
    channel. The column index is returned only when exactly one column
    matches; zero or several matches give None.
    """
    # assumes image is indexed (row, column, channel) — TODO confirm
    scanline = image[185, :, :]
    mismatch = np.abs(scanline - SHIP_COLOR).sum(axis=1)
    hits = np.argwhere(mismatch == 0)
    if len(hits) != 1:
        return None
    return hits[0][0]
space_invaders.py 文件源码 项目:human-rl 作者: gsastry 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def ship_location(image):
    """Return the column of the ship on scan line 185, or None.

    A column matches when its pixel in row 185 equals ``SHIP_COLOR`` in every
    channel. The column index is returned only when exactly one column
    matches; zero or several matches give None.
    """
    # assumes image is indexed (row, column, channel) — TODO confirm
    scanline = image[185, :, :]
    mismatch = np.abs(scanline - SHIP_COLOR).sum(axis=1)
    hits = np.argwhere(mismatch == 0)
    if len(hits) != 1:
        return None
    return hits[0][0]
space_invaders.py 文件源码 项目:human-rl 作者: gsastry 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def ship_location(image):
    """Return the column of the ship on scan line 185, or None.

    A column matches when its pixel in row 185 equals ``SHIP_COLOR`` in every
    channel. The column index is returned only when exactly one column
    matches; zero or several matches give None.
    """
    # assumes image is indexed (row, column, channel) — TODO confirm
    scanline = image[185, :, :]
    mismatch = np.abs(scanline - SHIP_COLOR).sum(axis=1)
    hits = np.argwhere(mismatch == 0)
    if len(hits) != 1:
        return None
    return hits[0][0]
space_invaders.py 文件源码 项目:human-rl 作者: gsastry 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def ship_location(image):
    """Return the column of the ship on scan line 185, or None.

    A column matches when its pixel in row 185 equals ``SHIP_COLOR`` in every
    channel. The column index is returned only when exactly one column
    matches; zero or several matches give None.
    """
    # assumes image is indexed (row, column, channel) — TODO confirm
    scanline = image[185, :, :]
    mismatch = np.abs(scanline - SHIP_COLOR).sum(axis=1)
    hits = np.argwhere(mismatch == 0)
    if len(hits) != 1:
        return None
    return hits[0][0]
Filters.py 文件源码 项目:NeoAnalysis 作者: neoanalysis 项目源码 文件源码 阅读 28 收藏 0 点赞 0 评论 0
def adjustXPositions(self, pts, data):
        """Return a list of Point() where the x position is set to the nearest x value in *data* for each point in *pts*.

        Also returns, per point, the ``np.argwhere`` array locating every
        element of *data* at the minimal distance (ties give several rows).
        """
        points = []
        timeIndices = []
        for pt in pts:
            # Distance of every sample to this point's x, computed once.
            distance = abs(data - pt.x())
            nearest = np.argwhere(distance == distance.min())
            points.append(Point(data[nearest], pt.y()))
            timeIndices.append(nearest)
        return points, timeIndices
kwikio.py 文件源码 项目:NeoAnalysis 作者: neoanalysis 项目源码 文件源码 阅读 19 收藏 0 点赞 0 评论 0
def read_spiketrain(self, cluster_id, model,
                        lazy=False,
                        cascade=True,
                        get_waveforms=True,
                        ):
        """
        Reads sorted spiketrains

        Parameters:
        get_waveforms: bool, default = True
            Whether or not to get the waveforms
        cluster_id: int,
            Which cluster to load, according to cluster id from klusta
        model: klusta.kwik.KwikModel
            A KwikModel object obtained by klusta.kwik.KwikModel(fname)
        lazy, cascade: bool
            Accepted for interface compatibility; not used by this method.

        Returns:
        SpikeTrain of the spikes in ``cluster_id`` (times in seconds), or
        None after printing a message if ``cluster_id`` is not in
        ``model.cluster_ids``.
        """
        if cluster_id not in model.cluster_ids:
            print("Exception: cluster_id (%d) not found !! " % cluster_id)
            return
        clusters = model.spike_clusters
        # Use 1-D spike positions. np.argwhere would give an (n, 1) array,
        # making times (n, 1) and waveforms 4-D, so swapaxes(1, 2) below could
        # not yield (num_spikes, num_chans, samples_per_spike).
        idx = np.flatnonzero(clusters == cluster_id)
        if get_waveforms:
            w = model.all_waveforms[idx]
            # klusta: num_spikes, samples_per_spike, num_chans = w.shape
            w = w.swapaxes(1, 2)
        else:
            w = None
        sptr = SpikeTrain(times=model.spike_times[idx],
                          t_stop=model.duration, waveforms=w, units='s',
                          sampling_rate=model.sample_rate*pq.Hz,
                          file_origin=self.filename,
                          **{'cluster_id': cluster_id})
        return sptr
Filters.py 文件源码 项目:NeoAnalysis 作者: neoanalysis 项目源码 文件源码 阅读 27 收藏 0 点赞 0 评论 0
def adjustXPositions(self, pts, data):
        """Return a list of Point() where the x position is set to the nearest x value in *data* for each point in *pts*.

        Also returns, per point, the ``np.argwhere`` array locating every
        element of *data* at the minimal distance (ties give several rows).
        """
        points = []
        timeIndices = []
        for pt in pts:
            # Distance of every sample to this point's x, computed once.
            distance = abs(data - pt.x())
            nearest = np.argwhere(distance == distance.min())
            points.append(Point(data[nearest], pt.y()))
            timeIndices.append(nearest)
        return points, timeIndices
kwikio.py 文件源码 项目:NeoAnalysis 作者: neoanalysis 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def read_spiketrain(self, cluster_id, model,
                        lazy=False,
                        cascade=True,
                        get_waveforms=True,
                        ):
        """
        Reads sorted spiketrains

        Parameters:
        get_waveforms: bool, default = True
            Whether or not to get the waveforms
        cluster_id: int,
            Which cluster to load, according to cluster id from klusta
        model: klusta.kwik.KwikModel
            A KwikModel object obtained by klusta.kwik.KwikModel(fname)
        lazy, cascade: bool
            Accepted for interface compatibility; not used by this method.

        Returns:
        SpikeTrain of the spikes in ``cluster_id`` (times in seconds), or
        None after printing a message if ``cluster_id`` is not in
        ``model.cluster_ids``.
        """
        if cluster_id not in model.cluster_ids:
            print("Exception: cluster_id (%d) not found !! " % cluster_id)
            return
        clusters = model.spike_clusters
        # Use 1-D spike positions. np.argwhere would give an (n, 1) array,
        # making times (n, 1) and waveforms 4-D, so swapaxes(1, 2) below could
        # not yield (num_spikes, num_chans, samples_per_spike).
        idx = np.flatnonzero(clusters == cluster_id)
        if get_waveforms:
            w = model.all_waveforms[idx]
            # klusta: num_spikes, samples_per_spike, num_chans = w.shape
            w = w.swapaxes(1, 2)
        else:
            w = None
        sptr = SpikeTrain(times=model.spike_times[idx],
                          t_stop=model.duration, waveforms=w, units='s',
                          sampling_rate=model.sample_rate*pq.Hz,
                          file_origin=self.filename,
                          **{'cluster_id': cluster_id})
        return sptr
gridworld_base.py 文件源码 项目:BlueWhale 作者: caffe2 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def reset(self):
        """Move the agent back to the start cell and return its state.

        The start cell is the first position (row-major order) of
        ``self.grid`` equal to ``S``. ``self._index`` converts that coordinate
        array into the value stored in ``self._state``. Raises IndexError if
        no cell equals ``S``.
        """
        start_coords = np.argwhere(self.grid == S)
        first_start = start_coords[0]
        self._state = self._index(first_start)
        return self._state
mcmc_func.py 文件源码 项目:CRN_ProbabilisticInversion 作者: elaloy 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def CalcDelta(nCR,delta_tot,delta_normX,CR):
    """Accumulate normalized jump distances per crossover value.

    For each crossover index ``zz`` in ``range(nCR)``, the chains whose
    crossover value equals ``(zz + 1) / nCR`` (exact float equality) are
    selected. Selection uses the row index of ``np.argwhere``, so ``CR`` may
    be 1-D or 2-D. The sum of their ``delta_normX`` entries is added to
    ``delta_tot[0, zz]``.

    ``delta_tot`` is updated in place and also returned.
    """
    for zz in range(nCR):
        crossover_value = (1.0 + zz) / nCR
        chain_rows = np.argwhere(CR == crossover_value)[:, 0]
        delta_tot[0, zz] += np.sum(delta_normX[chain_rows])
    return delta_tot
s3_geotiff_rdd_test.py 文件源码 项目:geopyspark 作者: locationtech-labs 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def assertTilesEqual(self, a, b):
        """compare two numpy arrays that are tiles

        Checks the shapes with ``self.assertEqual``, then every cell. On a
        mismatch raises ``Exception("Tiles differ at: ", count, coords)``.
        ``count`` is the number of differing cells and ``coords`` is the
        ``np.argwhere`` array of their positions. Returns True when equal.
        NaN cells never compare equal, so tiles containing NaN are reported
        as different.
        """
        self.assertEqual(a.shape, b.shape)
        diff = np.argwhere(np.logical_not(a == b))
        # len() counts differing cells. np.size(diff) would count coordinates
        # (cells * ndim) and is always 0 for 0-d tiles.
        if len(diff) > 0:
            raise Exception("Tiles differ at: ", len(diff), diff)
        return True
QLearning.py 文件源码 项目:ProbablisticRobotics2016 作者: RyuYamamoto 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def learn(self):
        """Take one epsilon-greedy step from ``self.state`` and update Q.

        With probability ``1 - self.epsilon`` (``np.random.rand() >
        self.epsilon``), one of the available actions tied for the highest
        ``self.q[action, y, x]`` is chosen uniformly at random. Otherwise a
        random available action is chosen. The move for that action from
        ``self.move_list`` is passed to ``self.update_q``.
        ``self.q_max_value(move)`` is appended to ``self.q_value_list``, and
        the move is added to ``self.state``.
        """
        y, x = self.state
        # Copy so the shuffle below cannot reorder the stored action list.
        current_actions = copy.deepcopy(self.action_list[y, x])
        if np.random.rand() > self.epsilon:
            q_values = self.q[current_actions, y, x]
            # flatnonzero gives integer positions. np.argwhere rows are
            # 1-element arrays, which cannot index a list and are not
            # hashable as move_list keys.
            best = np.flatnonzero(q_values == q_values.max())
            action = current_actions[random.choice(best)]
        else:
            random.shuffle(current_actions)
            action = current_actions[0]
        move = self.move_list.get(action)
        self.update_q(action, move)
        self.q_value_list.append(self.q_max_value(move))
        self.state += move
test.py 文件源码 项目:tensorflow_ocr 作者: BowieHsu 项目源码 文件源码 阅读 65 收藏 0 点赞 0 评论 0
def pixel_detect(score_map, geo_map, score_map_thresh=0.8, link_thresh=0.8):
    '''
    Build a binary pixel map from a score map and a link map.

    A pixel is 1 when its score exceeds ``score_map_thresh`` and none of the
    8 odd channels of ``geo_map`` (1, 3, ..., 15) at that pixel is below
    ``link_thresh``. Otherwise it is 0. The inputs are not modified.

    :param score_map: (H, W) score map, or a (1, H, W, 1) batch of one
    :param geo_map: (H, W, C>=16) link map, or a (1, H, W, C) batch of one
    :param score_map_thresh: threshold for the score map
    :param link_thresh: threshold for the odd (link) channels of geo_map
    :return: (H, W) numpy.uint8 array of 0/1 values
    '''
    if len(score_map.shape) == 4:
        score_map = score_map[0, :, :, 0]
        geo_map = geo_map[0, :, :, ]

    # Keep pixels whose score clears the threshold (a new bool array).
    res_map = score_map > score_map_thresh

    # Drop pixels where any link channel is weak. A boolean mask addresses the
    # (row, col) positions directly; indexing with rows of np.argwhere output
    # would pick the wrong coordinates.
    for i in range(8):
        res_map[geo_map[:, :, i * 2 + 1] < link_thresh] = False

    return res_map.astype(np.uint8)


问题


面经


文章

微信
公众号

扫码关注公众号