TargetDataSampler.py source code

python

Project: bnpy  Author: bnpy
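import numpy as np


def _sample_target_GroupXData(Data, model, LP, **kwargs):
    ''' Draw a random subset of the atoms in the provided GroupXData dataset.

    Atoms whose responsibility for component targetCompID exceeds
    targetCompFrac are eligible; at most targetMaxSize of them are kept.
    '''
    randstate = kwargs['randstate']
    # hasValidKey is assumed to be a helper defined elsewhere in this module (not shown).
    if not hasValidKey('targetCompID', kwargs):
        raise NotImplementedError('TODO')

    ktarget = kwargs['targetCompID']
    targetProbThr = kwargs['targetCompFrac']
    # Keep only atoms whose responsibility for the target component exceeds the threshold.
    mask = LP['resp'][:, ktarget] > targetProbThr
    objIDs = np.flatnonzero(mask)
    if len(objIDs) < 2:
        # Too few candidate atoms to form a useful target set.
        return None, dict()
    # Shuffle in place, then truncate to at most targetMaxSize atoms.
    randstate.shuffle(objIDs)
    targetObjIDs = objIDs[:kwargs['targetMaxSize']]
    TargetData = Data.select_subset_by_mask(atomMask=targetObjIDs,
                                            doTrackFullSize=False)
    TargetInfo = dict(ktarget=ktarget)
    return TargetData, TargetInfo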
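The index-selection step in the function above (threshold the responsibilities, shuffle, truncate) can be tried out on its own. The following is a minimal sketch assuming only NumPy; the function name `sample_target_atom_ids` and its parameters are hypothetical stand-ins for the `targetCompID`, `targetCompFrac` and `targetMaxSize` keyword arguments, and it stops short of calling `select_subset_by_mask`, which needs a real bnpy `GroupXData` object.

```python
import numpy as np


def sample_target_atom_ids(resp, ktarget, thresh, max_size, randstate):
    """Hypothetical standalone helper mirroring the selection logic above.

    resp: (N, K) array of per-atom responsibilities.
    Returns at most max_size atom indices whose responsibility for
    component ktarget exceeds thresh, in random order, or None if
    fewer than two atoms qualify.
    """
    # Boolean mask over atoms, then the integer indices where it is True.
    obj_ids = np.flatnonzero(resp[:, ktarget] > thresh)
    if len(obj_ids) < 2:
        return None
    # In-place shuffle so that the truncation below keeps a random subset.
    randstate.shuffle(obj_ids)
    return obj_ids[:max_size]


if __name__ == "__main__":
    resp = np.array([
        [0.9, 0.1],
        [0.2, 0.8],
        [0.7, 0.3],
        [0.6, 0.4],
        [0.1, 0.9],
    ])
    rng = np.random.RandomState(0)
    ids = sample_target_atom_ids(resp, ktarget=0, thresh=0.5,
                                 max_size=2, randstate=rng)
    print(sorted(ids.tolist()))  # two of the eligible atoms 0, 2, 3
```

Shuffling before slicing means calls with different random states can return different subsets of the eligible atoms, while the cap keeps the target set small no matter how many atoms pass the threshold.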