python类tqdm()的实例源码

test_connectionpools.py 文件源码 项目:cloud-volume 作者: seung-lab 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def test_gc_stresstest():
  """Stress-test the Google Cloud Storage connection pool (GC_POOL).

  Uploads a small test object, then queues 500 fetches of it on a
  ThreadedQueue(n_threads=20); each fetch checks a connection out of GC_POOL,
  downloads the blob and releases the connection. The progress bar is closed
  even if the queue raises.
  """
  with Storage('gs://seunglab-test/cloudvolume/connection_pool/', n_threads=0) as stor:
    stor.put_file('test', 'some string')

  n_trials = 500

  with tqdm(total=n_trials) as pbar:

    @retry
    def create_conn(interface):
      # `interface` is presumably supplied by ThreadedQueue; it is unused here.
      bucket = GC_POOL.get_connection()
      blob = bucket.get_blob('cloudvolume/connection_pool/test')
      blob.download_as_string()
      GC_POOL.release_connection(bucket)
      pbar.update()

    with ThreadedQueue(n_threads=20) as tq:
      for _ in range(n_trials):
        tq.put(create_conn)
build_feature_files.py 文件源码 项目:human-rl 作者: gsastry 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def action_label_counts(directory, data_loader, n_actions=18, n=None):
    """Count label occurrences overall and per action across episodes.

    Args:
        directory: directory passed to ``frame.episode_paths`` to list episodes.
        data_loader: object whose ``load_features_and_labels([path])`` returns
            ``(features, labels)`` for one episode.
        n_actions: number of action ids (0 .. n_actions - 1) to tally.
        n: if given, the episode list is shuffled and only the first ``n``
            episodes are counted.

    Returns:
        ``(label_counts, action_label_counts)``: ``label_counts[label]`` is the
        number of entries in ``labels`` equal to ``label`` (0 or 1), and
        ``action_label_counts[action][label]`` the number of entries with that
        label whose action equals ``action``.

    Episodes that fail to load are reported with a traceback and skipped.
    """
    episode_paths = frame.episode_paths(directory)
    label_counts = [0, 0]
    action_label_counts = [[0, 0] for _ in range(n_actions)]
    if n is not None:
        np.random.shuffle(episode_paths)
        episode_paths = episode_paths[:n]
    for episode_path in tqdm.tqdm(episode_paths):
        try:
            features, labels = data_loader.load_features_and_labels([episode_path])
        except Exception:
            # Best-effort: report the broken episode and keep going. A bare
            # `except:` would also swallow KeyboardInterrupt/SystemExit.
            traceback.print_exc()
            continue
        # The flattened action array does not depend on label or action, so
        # build it once per episode instead of once per (label, action) pair.
        # NOTE(review): assumes `labels` is a NumPy array aligned with it.
        actions = np.reshape(np.array(features["action"]), [-1])
        for label in range(len(label_counts)):
            is_label = labels == label
            label_counts[label] += np.count_nonzero(is_label)
            for action in range(n_actions):
                action_label_counts[action][label] += np.count_nonzero(
                    np.logical_and(is_label, actions == action))
    return label_counts, action_label_counts
extract.py 文件源码 项目:twentybn-dl 作者: TwentyBN 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def extract_chunks(files, num_images, out_path):
    """Untar concatenated chunk files into ``out_path`` with a progress bar.

    The chunk files are piped through ``cat`` into ``tar xvz``; every line of
    tar's output ending in ``.jpg`` advances the bar by one.

    Args:
        files: chunk file paths, passed to ``cat`` in the given order.
        num_images: expected number of images (progress-bar total).
        out_path: working directory for ``tar`` (the extraction target).
    """
    with tqdm(total=num_images,
              unit='images',
              ncols=80,
              unit_scale=True) as pbar:
        process = tar(cat(files, _piped=True), 'xvz', _iter=True, _cwd=out_path)

        def kill():
            # Best-effort: the process may already have exited.
            try:
                process.kill()
            except Exception:
                pass

        # Ensure tar does not outlive the interpreter if we exit mid-extraction.
        atexit.register(kill)
        try:
            for line in process:
                if line.strip().endswith('.jpg'):
                    pbar.update(1)
        except BaseException:
            # Interrupted or failed: stop tar now rather than at exit.
            kill()
            raise
        finally:
            # Drop the hook once this extraction is over so repeated calls do
            # not pile up handlers that keep finished processes referenced.
            atexit.unregister(kill)
test_connectionpools.py 文件源码 项目:cloud-volume 作者: seung-lab 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def test_s3_stresstest():
  """Stress-test the S3 connection pool (S3_POOL).

  Uploads a small test object, then queues 500 `get_object` requests for it on
  a ThreadedQueue(n_threads=20); each request checks a connection out of
  S3_POOL and releases it again. The progress bar is closed even if the queue
  raises.
  """
  with Storage('s3://seunglab-test/cloudvolume/connection_pool/', n_threads=0) as stor:
    stor.put_file('test', 'some string')

  n_trials = 500

  with tqdm(total=n_trials) as pbar:

    @retry
    def create_conn(interface):
      # `interface` is presumably supplied by ThreadedQueue; it is unused here.
      conn = S3_POOL.get_connection()
      # Only the request is exercised; its response is not inspected.
      conn.get_object(
        Bucket='seunglab-test',
        Key='cloudvolume/connection_pool/test',
      )
      S3_POOL.release_connection(conn)
      pbar.update()

    with ThreadedQueue(n_threads=20) as tq:
      for _ in range(n_trials):
        tq.put(create_conn)
lcproc.py 文件源码 项目:astrobase 作者: waqasbhatti 项目源码 文件源码 阅读 24 收藏 0 点赞 0 评论 0
def serial_varfeatures(lclist,
                       outdir,
                       maxobjects=None,
                       timecols=None,
                       magcols=None,
                       errcols=None,
                       mindet=1000,
                       lcformat='hat-sql',
                       nworkers=None):
    """Run ``varfeatures_worker`` serially over a list of light curves.

    Args:
        lclist: light-curve inputs, one task per element.
        outdir: output directory, forwarded to each task.
        maxobjects: if truthy, only the first ``maxobjects`` entries are processed.
        timecols, magcols, errcols, mindet, lcformat: forwarded unchanged in
            each task tuple.
        nworkers: unused by this serial implementation.

    Returns:
        list: the value returned by ``varfeatures_worker`` for each task, in
        the order of ``lclist``.
    """
    if maxobjects:
        lclist = lclist[:maxobjects]

    tasks = [(x, outdir, timecols, magcols, errcols, mindet, lcformat)
             for x in lclist]

    # Keep the worker results instead of discarding them.
    return [varfeatures_worker(task) for task in tqdm(tasks)]
train.py 文件源码 项目:DeepWorks 作者: daigo0927 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def valid(self, batch_size = 128, weights_file = None):
        """Evaluate accuracy on ``self.x_test`` / ``self.y_test`` and print it.

        Args:
            batch_size: number of samples fed to ``self.accuracy`` per run.
            weights_file: if given, weights are restored from it first via
                ``self.saver.restore``.

        The test set is visited in a random permutation. Every sample is
        evaluated, including those in a final partial batch.
        """
        if weights_file is not None:
            self.saver.restore(self.sess, weights_file)

        data_size = self.x_test.shape[0]
        # Ceil-divide so the trailing partial batch is evaluated too; plain
        # truncation skipped up to batch_size - 1 samples (all of them when
        # data_size < batch_size, which produced a NaN mean).
        num_batches = -(-data_size // batch_size)

        acc_vals = []
        batch_sizes = []
        permute_idx = np.random.permutation(np.arange(data_size))
        for b in tqdm(np.arange(num_batches)):
            batch_idx = permute_idx[b*batch_size:(b+1)*batch_size]
            x_val = self.x_test[batch_idx]
            y_val = self.y_test[batch_idx]

            acc_val = self.sess.run(self.accuracy,
                                    feed_dict = {self.images:x_val, self.labels:y_val})
            acc_vals.append(acc_val)
            batch_sizes.append(len(batch_idx))

        # Weight each batch by its size so a short last batch does not count as
        # much as a full one. NOTE(review): assumes self.accuracy is the mean
        # accuracy over the fed batch.
        if acc_vals:
            accuracy = np.average(acc_vals, weights=batch_sizes)
        else:
            accuracy = float('nan')
        print('validation accuracy : {}'.format(accuracy))
simplehttp.py 文件源码 项目:skymod 作者: DelusionalLogic 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def download_file(self, name, url, headers, filename):
        """Stream ``url`` into ``filename`` while showing a byte progress bar.

        Args:
            name: label shown on the progress bar.
            url: URL to download (redirects are followed).
            headers: extra HTTP headers for the request.
            filename: local path the body is written to.

        Raises:
            RuntimeError: if the final response status is not 200.
        """
        r = super().getSession().get(
            url,
            allow_redirects=True,
            headers=headers,
            stream=True
        )

        # With stream=True the connection stays checked out until the body is
        # consumed or the response is closed, so close it on every exit path,
        # including the non-200 error below.
        try:
            if r.status_code != 200:
                raise RuntimeError(
                    "Failed downloading file due to non 200 return code. "
                    "Return code was " + str(r.status_code)
                )

            # 0 when the server sends no Content-Length (unknown total).
            total_size = int(r.headers.get("content-length", 0))
            with tqdm(desc=name, total=total_size, unit='B',
                      unit_scale=True, miniters=1) as bar:
                with open(filename, 'wb') as fd:
                    for chunk in r.iter_content(32*1024):
                        bar.update(len(chunk))
                        fd.write(chunk)
        finally:
            r.close()
gempro.py 文件源码 项目:ssbio 作者: SBRG 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def pdb_downloader_and_metadata(self, outdir=None, pdb_file_type=None, force_rerun=False):
        """Download all mapped experimental structures for every gene.

        Delegates to ``gene.protein.pdb_downloader_and_metadata`` for each gene
        in ``self.genes`` and logs how many structures were saved in total.

        Args:
            outdir (str): Output directory to use instead of the GEM-PRO
                structure directories (e.g. when those were not set).
            pdb_file_type (str): PDB file format to download; when falsy,
                ``self.pdb_file_type`` is used.
            force_rerun (bool): Re-download files even if they already exist.
        """
        file_type = pdb_file_type if pdb_file_type else self.pdb_file_type

        n_saved = 0
        for gene in tqdm(self.genes):
            downloaded = gene.protein.pdb_downloader_and_metadata(
                outdir=outdir,
                pdb_file_type=file_type,
                force_rerun=force_rerun,
            )
            # The per-protein call may return nothing (falsy) when no structure was saved.
            if downloaded:
                n_saved += len(downloaded)

        log.info('Updated PDB metadata dataframe. See the "df_pdb_metadata" attribute for a summary dataframe.')
        log.info('Saved {} structures total'.format(n_saved))
GeneBot.py 文件源码 项目:scheduled-bots 作者: SuLab 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def run(self, records, total=None, fast_run=True, write=True):
        """Disabled entry point: always raises ``ValueError``.

        This variant is not meant to be called any more. The per-record loop
        that used to follow the ``raise`` could never execute, and it called
        ``self.GENE_CLASS`` with a stale argument list (no ``chr_num_wdid``),
        so it is not kept here.

        Raises:
            ValueError: always, before touching ``records``.
        """
        raise ValueError("run() is disabled and should not be called")
GeneBot.py 文件源码 项目:scheduled-bots 作者: SuLab 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def cleanup(self, releases, last_updated):
        """Remove deprecated statements from this organism's gene items.

        Items are found by Entrez Gene ID (P351) within ``self.organism_info``'s
        taxon. Items whose Entrez ID is listed in ``self.failed`` are skipped.
        Prints the failed list and the item counts before and after filtering.

        :param releases: forwarded to ``remove_deprecated_statements``.
        :param last_updated: forwarded to ``remove_deprecated_statements``.
        :return: None
        """
        print(self.failed)
        entrez_qid = wdi_helpers.id_mapper('P351', ((PROPS['found in taxon'], self.organism_info['wdid']),))
        print(len(entrez_qid))
        # self.failed is a list; a set makes each membership test O(1).
        failed = set(self.failed)
        entrez_qid = {entrez: qid for entrez, qid in entrez_qid.items() if entrez not in failed}
        print(len(entrez_qid))
        # Named base_filter so the builtin `filter` is not shadowed.
        base_filter = {PROPS['Entrez Gene ID']: '', PROPS['found in taxon']: self.organism_info['wdid']}
        frc = FastRunContainer(wdi_core.WDBaseDataType, wdi_core.WDItemEngine, base_filter=base_filter, use_refs=True)
        frc.clear()
        for qid in tqdm(entrez_qid.values()):
            remove_deprecated_statements(qid, frc, releases, last_updated, list(PROPS.values()), self.login)
GeneBot.py 文件源码 项目:scheduled-bots 作者: SuLab 项目源码 文件源码 阅读 26 收藏 0 点赞 0 评论 0
def run(self, records, total=None, fast_run=True, write=True):
        """Create or update a Wikidata item for every record that passes ``self.filter``.

        When ``create_item`` raises, the traceback is printed, an ERROR line is
        logged via ``wdi_core.WDItemEngine.log`` and the message becomes the
        gene's status. Any gene whose status is not ``True`` afterwards has its
        ``entrez`` appended to ``self.failed``.

        Args:
            records: source records; filtered with ``self.filter`` first.
            total: optional record count for the progress bar.
            fast_run: forwarded to ``create_item``.
            write: forwarded to ``create_item``.
        """
        selected = self.filter(records)
        for rec in tqdm(selected, mininterval=2, total=total):
            gene = self.GENE_CLASS(rec, self.organism_info, self.chr_num_wdid, self.login)
            try:
                gene.create_item(fast_run=fast_run, write=write)
            except Exception as err:
                traceback.print_exception(*sys.exc_info())
                msg = wdi_helpers.format_msg(
                    gene.external_ids['Entrez Gene ID'],
                    PROPS['Entrez Gene ID'],
                    None,
                    str(err),
                    msg_type=type(err),
                )
                wdi_core.WDItemEngine.log("ERROR", msg)
                gene.status = msg
            succeeded = gene.status is True
            if not succeeded:
                self.failed.append(gene.entrez)
ProteinBot.py 文件源码 项目:scheduled-bots 作者: SuLab 项目源码 文件源码 阅读 25 收藏 0 点赞 0 评论 0
def run(self, records, total=None, fast_run=True, write=True):
        """Create protein items for each record that passes ``self.filter``.

        Records whose Entrez gene has no entry in ``self.gene_wdid_mapping`` are
        logged as WARNING and skipped. A record with several Swiss-Prot
        accessions is passed to ``run_one`` once per accession.

        Args:
            records: source records.
            total: optional record count for the progress bar.
            fast_run: not used by this method.
            write: forwarded to ``run_one``.
        """
        records = self.filter(records)
        for record in tqdm(records, mininterval=2, total=total):
            entrez_gene = str(record['entrezgene']['@value'])
            if entrez_gene not in self.gene_wdid_mapping:
                wdi_core.WDItemEngine.log("WARNING", format_msg(entrez_gene, "P351", None,
                                                                "Gene item not found during protein creation", None))
                continue
            gene_wdid = self.gene_wdid_mapping[entrez_gene]

            # handle multiple proteins
            if 'uniprot' in record and 'Swiss-Prot' in record['uniprot']['@value']:
                uniprots = record['uniprot']['@value']['Swiss-Prot']
                # A single accession may be stored as a bare string; iterating
                # that directly would call run_one once per *character*.
                if isinstance(uniprots, str):
                    uniprots = [uniprots]
                for uniprot in uniprots:
                    # Narrow the record to one accession at a time (presumably
                    # run_one reads it from the record).
                    record['uniprot']['@value']['Swiss-Prot'] = uniprot
                    self.run_one(record, gene_wdid, write)
            else:
                self.run_one(record, gene_wdid, write)
tracker.py 文件源码 项目:scheduled-bots 作者: SuLab 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def lookupLabels(changes):
        """Fetch labels for every entity referenced by ``changes`` and attach them.

        Collects the property ids, item ids, item-valued values
        (``PROP_TYPE.get(pid) == "WikibaseItem"``) and reference props/values of
        all changes. Falsy ids are dropped. Labels are fetched with
        ``getConceptLabels``, with ``chunks`` splitting the ids using a chunk
        size of 500. The function then sets:

        - ``pid_label``, ``qid_label`` and ``value_label`` on each change, when a
          label was found;
        - ``value_label`` and ``prop_label`` on each of the change's references,
          or ``''`` when no label was found.

        ``changes`` is iterated several times, so it must be a re-iterable
        collection such as a list. It is mutated in place; returns None.
        """
        chunk_size = 500
        pids = set(s.pid for s in changes)
        qids = set(s.qid for s in changes)
        values = set(s.value for s in changes if s.value and PROP_TYPE.get(s.pid) == "WikibaseItem")
        ref_qids = set(chain(*[
            [s['value'] for s in change.ref_list if s['value'] and PROP_TYPE.get(s['prop']) == "WikibaseItem"]
            for change in changes]))
        ref_pids = set(chain(*[[s['prop'] for s in change.ref_list] for change in changes]))
        labels = dict()
        x = pids | qids | values | ref_qids | ref_pids
        x = set(y for y in x if y)
        # tqdm expects the integer number of chunks, ceil(len(x) / chunk_size).
        # The old `len(x) / 500` was a float that under-counted the last partial chunk.
        n_chunks = (len(x) + chunk_size - 1) // chunk_size
        for chunk in tqdm(chunks(x, chunk_size), total=n_chunks):
            labels.update(getConceptLabels(tuple(chunk)))

        for c in changes:
            if c.pid and c.pid in labels:
                c.pid_label = labels[c.pid]
            if c.qid and c.qid in labels:
                c.qid_label = labels[c.qid]
            if c.value and c.value in labels:
                c.value_label = labels[c.value]
            for ref in c.ref_list:
                ref['value_label'] = labels.get(ref['value'], '')
                ref['prop_label'] = labels.get(ref['prop'], '')
tracker.py 文件源码 项目:scheduled-bots 作者: SuLab 项目源码 文件源码 阅读 40 收藏 0 点赞 0 评论 0
def get_revisions_past_weeks(qids, weeks):
    """
    Get the revision IDs for revisions on `qids` items in the past `weeks` weeks.

    Runs one query per week through ``query_wikidata_mysql``. Iteration
    ``week`` selects revisions with
    NOW() - (week + 1) weeks < rev_timestamp < NOW() - week weeks
    on pages whose content model is "wikibase-item" and whose title is one of
    ``qids``. The ``rev_id`` values of every weekly result are unioned.

    :param qids: set of qids (matched against ``page.page_title``)
    :param weeks: int, number of one-week windows to look back
    :return: set of revision ids
    """
    revisions = set()
    # NOTE(review): qids are interpolated into the SQL text inside double
    # quotes with no escaping, so a value containing `"` would break (or
    # inject into) the query. Only pass trusted ids, or switch to a
    # parameterized query if query_wikidata_mysql supports one.
    qids_str = '"' + '","'.join(qids) + '"'
    for week in tqdm(range(weeks)):
        query = '''select rev_id, rev_page, rev_timestamp, page_id, page_namespace, page_title, page_touched FROM revision
                           inner join page on revision.rev_page = page.page_id WHERE
                           rev_timestamp > DATE_FORMAT(DATE_SUB(DATE_SUB(NOW(),INTERVAL {week} WEEK), INTERVAL 1 WEEK),'%Y%m%d%H%i%s') AND
                           rev_timestamp < DATE_FORMAT(DATE_SUB(NOW(), INTERVAL {week} WEEK),'%Y%m%d%H%i%s') AND
                           page_content_model = "wikibase-item" AND
                           page.page_title IN({qids});
                    '''.format(qids=qids_str, week=week)
        revision_df = query_wikidata_mysql(query)
        # Debug output: row count plus the first and last two rows of each week.
        print(len(revision_df))
        print(revision_df.head(2))
        print(revision_df.tail(2))
        revisions.update(set(revision_df.rev_id))
    return revisions
bot.py 文件源码 项目:scheduled-bots 作者: SuLab 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def main(chebi_iedb_map, log_dir="./logs", fast_run=False, write=True):
    """Add IEDB Epitope IDs to the Wikidata items of the matching ChEBI entries.

    ChEBI ids with no Wikidata item are printed and logged as WARNING and
    skipped.

    Args:
        chebi_iedb_map: mapping of ChEBI id -> IEDB Epitope id.
        log_dir: directory for the wdi_core log.
        fast_run: forwarded to ``WDItemEngine``.
        write: forwarded to ``wdi_helpers.try_write``.
    """
    login = wdi_login.WDLogin(user=WDUSER, pwd=WDPASS)
    wdi_core.WDItemEngine.setup_logging(log_dir=log_dir, logger_name='WD_logger', log_name=log_name,
                                        header=json.dumps(__metadata__))

    chebi_to_qid = id_mapper(PROPS['ChEBI-ID'])

    for chebi, iedb in tqdm(chebi_iedb_map.items()):
        if chebi in chebi_to_qid:
            statements = [
                wdi_core.WDExternalID(iedb, PROPS['IEDB Epitope ID'], references=create_references(iedb)),
            ]
            item = wdi_core.WDItemEngine(
                wd_item_id=chebi_to_qid[chebi],
                data=statements,
                domain="drugs",
                fast_run=fast_run,
                fast_run_base_filter={PROPS['ChEBI-ID']: ''},
                fast_run_use_refs=True,
                ref_handler=ref_handlers.update_retrieved_if_new,
                global_ref_mode="CUSTOM",
            )
            wdi_helpers.try_write(item, iedb, PROPS['IEDB Epitope ID'], login,
                                  edit_summary="Add IEDB Epitope ID", write=write)
        else:
            warning = wdi_helpers.format_msg(iedb, PROPS['IEDB Epitope ID'], None,
                                             "ChEBI:{} not found".format(chebi), "ChEBI not found")
            print(warning)
            wdi_core.WDItemEngine.log("WARNING", warning)
ProteinBot.py 文件源码 项目:scheduled-bots 作者: SuLab 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def create_uniprot_relationships(login, release_wdid, collection, taxon=None, write=True, run_one=False):
    """Process UniProt proteins that already have a Wikidata item.

    Only documents whose ``_id`` is a UniProt id already mapped to a Wikidata
    item are processed; when ``taxon`` is given, the mapping is restricted to
    that taxon via P703. Each document is passed to ``create_for_one_protein``.

    Args:
        login: forwarded to ``create_for_one_protein``.
        release_wdid: forwarded to ``create_for_one_protein``.
        collection: MongoDB-style collection holding the protein documents.
        taxon: optional taxon QID used for the id mapping and fast-run filter.
        write: forwarded to ``create_for_one_protein``.
        run_one: stop after the first document.

    Returns:
        list of QIDs of items that were modified or skipped (newly created
        items are excluded).
    """
    if taxon:
        uniprot2wd = wdi_helpers.id_mapper(UNIPROT, (("P703", taxon),))
        fast_run_base_filter = {UNIPROT: "", "P703": taxon}
    else:
        uniprot2wd = wdi_helpers.id_mapper(UNIPROT)
        fast_run_base_filter = {UNIPROT: ""}

    query = {'_id': {'$in': list(uniprot2wd.keys())}}
    cursor = collection.find(query).batch_size(20)
    # Cursor.count() was deprecated in PyMongo 3.7 and removed in 4.0. Fall back
    # to Collection.count_documents() with the same filter when it is gone.
    # Only the progress-bar total depends on this.
    total = cursor.count() if hasattr(cursor, 'count') else collection.count_documents(query)

    qids = []
    for doc in tqdm(cursor, total=total, mininterval=10.0):
        wd_item = create_for_one_protein(login, doc, release_wdid, uniprot2wd, fast_run_base_filter, write=write)
        if wd_item and not wd_item.create_new_item:
            qids.append(wd_item.wd_item_id)
        if run_one:
            break
    return qids
main.py 文件源码 项目:SharesData 作者: xjkj123 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def UpDataShare():
    """Run ChildThead once per share code, throttling how many threads are alive.

    One thread is created per code from ``Tools().GetShareCode()``. Threads are
    started one at a time. After each start, the main thread polls every 50 ms
    until fewer than ``MaxThread`` threads (the main thread included) are
    alive. Each of the last 12 started threads is also joined before moving on.

    An exception is reported with the "1223" marker plus a traceback instead of
    being silently discarded. KeyboardInterrupt is no longer swallowed.
    """
    import traceback

    thread = []
    MaxThread = 3
    num = 0
    code = Tools().GetShareCode()
    for x in code:
        y = threading.Thread(target=ChildThead, args=(x,))
        thread.append(y)
    try:
        for t in tqdm(thread):
            t.start()
            while True:
                time.sleep(0.05)
                if len(threading.enumerate()) < MaxThread:
                    if len(code) - num < 13:
                        t.join()
                    num = num + 1
                    break
    except Exception:
        # Single-argument print(...) behaves identically on Python 2 and 3.
        print("1223")
        traceback.print_exc()
imdb_crawl.py 文件源码 项目:holcrawl 作者: shaypal5 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def crawl_by_file(file_path, verbose, year=None):
    """Crawls IMDB and builds movie profiles for the movies in the given file.

    Args:
        file_path: file listing the movie titles (read by ``_titles_from_file``).
        verbose: print progress messages; also forwarded to ``crawl_by_title``.
        year: optional year forwarded to ``crawl_by_title``.

    Afterwards, prints how many titles were crawled and how many fell into each
    result type of ``_result.ALL_TYPES``.
    """
    results = {res_type: 0 for res_type in _result.ALL_TYPES}
    titles = _titles_from_file(file_path)
    if verbose:
        print("Crawling over all {} IMDB movies in {}...".format(
            len(titles), file_path))
    # The bar is also handed to crawl_by_title. The context manager closes it
    # even if a crawl raises, and before the summary below is printed.
    with tqdm(titles, miniters=1, maxinterval=0.0001,
              mininterval=0.00000000001, total=len(titles)) as movie_pbar:
        for title in movie_pbar:
            res = crawl_by_title(title, verbose, year, movie_pbar)
            results[res] += 1
    print("{} IMDB movie profiles crawled.".format(len(titles)))
    for res_type in _result.ALL_TYPES:
        print('{} {}.'.format(results[res_type], res_type))


# === uniting movie profiles to csv ===
dataset.py 文件源码 项目:holcrawl 作者: shaypal5 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def build_united_profiles(verbose):
    """Merge each movie's IMDB and Metacritic JSON profiles into one file.

    Profile names come from ``_prof_names_in_all_resources()`` and are processed
    in sorted order. For keys present in both profiles, the Metacritic value
    wins. Each merged profile is written to ``_UNITED_DIR_PATH/<name>.json``
    with 2-space indentation and sorted keys. The output directory is created
    if needed.
    """
    os.makedirs(_UNITED_DIR_PATH, exist_ok=True)
    names = sorted(_prof_names_in_all_resources())
    if verbose:
        print("Building movie profiles with data from all resources.")
        names = tqdm(names)
    for name in names:
        json_name = name + '.json'
        loaded = []
        for source_dir in (_IMDB_DIR_PATH, _METACRITIC_DIR_PATH):
            with open(os.path.join(source_dir, json_name), 'r') as source_file:
                loaded.append(json.load(source_file))
        imdb_profile, meta_profile = loaded
        merged = {**imdb_profile, **meta_profile}
        target_path = os.path.join(_UNITED_DIR_PATH, json_name)
        with open(target_path, 'w+') as target_file:
            json.dump(merged, target_file, indent=2, sort_keys=True)
progress_bar.py 文件源码 项目:pyrsss 作者: butala 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def tqdm_callback(N, notebook=True):
    """
    Return a :mod:`tqdm` progress bar expecting *N* iterations.

    The bar is created with ``tqdm.tqdm_notebook`` (for Jupyter) when
    *notebook* is true and with ``tqdm.tqdm`` (for the terminal) otherwise.
    It gains an extra attribute ``callback``: a function that advances the bar
    by one step on every call and ignores whatever arguments it is given, so it
    can be passed as a progress callback.
    """
    if notebook:
        progress_bar = tqdm.tqdm_notebook(total=N)
    else:
        progress_bar = tqdm.tqdm(total=N)

    def callback(self, *args, **kwargs):
        # Caller-supplied arguments are accepted and ignored; every call counts
        # as one completed iteration. (Previously exactly one argument was required.)
        self.update()
    progress_bar.callback = partial(callback, progress_bar)
    return progress_bar
utils.py 文件源码 项目:deeppavlov 作者: deepmipt 项目源码 文件源码 阅读 22 收藏 0 点赞 0 评论 0
def _generate_all_features(self):
        """
            Pre-generate the feature vector of every mention to speed up get_batch.

            Fills ``self.mention_features`` (mention -> features), stores the
            length of the last mention's features in ``self.features_size``,
            and then sets ``self.embeddings`` to None to free it.
            ``self.features`` must be set but is not released here.

            Raises:
                AssertionError: (via ``assert``) if ``self.embeddings`` or
                    ``self.features`` is None.
                ValueError: if no document has any mention, so no feature size
                    can be determined.
        """
        print('DataLoader: generating all features')
        assert self.embeddings is not None
        assert self.features is not None

        mention = None
        n_mentions = 0
        for ms in tqdm(self.document_mentions):
            for mention in ms:
                self.mention_features[mention] = self._make_mention_features(mention)
                n_mentions += 1

        # Without this check the leaked loop variable would be unbound (or stale)
        # when there are no mentions at all.
        if n_mentions == 0:
            raise ValueError('DataLoader: no mentions to generate features for')
        self.features_size = len(self.mention_features[mention])
        self.embeddings = None
        print('DataLoader: generating all features finished')
extract_database.py 文件源码 项目:wurst 作者: IndEcol 项目源码 文件源码 阅读 23 收藏 0 点赞 0 评论 0
def add_input_info_for_external_exchanges(activities, names):
    """Copy input details onto exchanges whose input comes from another database.

    For every exchange that has an ``'input'`` key whose database (element 0)
    is not in ``names``, the input activity is fetched from ``ActivityDataset``.
    Each distinct input is fetched only once per call. Its name, product, unit
    and location are then copied onto the exchange, plus its categories for
    biosphere exchanges. Exchange dicts are mutated in place; returns None.
    """
    internal_dbs = set(names)
    fetched = {}

    for dataset in tqdm(activities):
        for exchange in dataset['exchanges']:
            if 'input' not in exchange:
                continue
            key = exchange['input']
            if key[0] in internal_dbs:
                continue
            if key not in fetched:
                fetched[key] = ActivityDataset.get(
                    ActivityDataset.database == key[0],
                    ActivityDataset.code == key[1],
                )
            source = fetched[key]
            exchange['name'] = source.name
            exchange['product'] = source.product
            exchange['unit'] = source.data['unit']
            exchange['location'] = source.location
            if exchange['type'] == 'biosphere':
                exchange['categories'] = source.data['categories']
ubuntudata.py 文件源码 项目:hadan-gcloud 作者: youkpan 项目源码 文件源码 阅读 31 收藏 0 点赞 0 评论 0
def __init__(self, dirName):
        """
        Load Ubuntu dialog conversations from ``<dirName>/dialogs``.

        Each ``.tsv`` file inside a sub-directory becomes one entry
        ``{"lines": self.loadLines(path)}`` in ``self.conversations``.
        After ``self.MAX_NUMBER_SUBDIR`` sub-directories have been read, a
        warning is printed when the next entry is reached and loading stops.

        Args:
            dirName (string): directory where to load the corpus
        """
        self.MAX_NUMBER_SUBDIR = 10
        self.conversations = []
        # Plain local name (a `__dir` local would be name-mangled inside the class).
        dialogs_dir = os.path.join(dirName, "dialogs")
        number_subdir = 0
        # The bar is a context manager so it is also closed on the early
        # `return`, instead of being left open and half-drawn.
        with tqdm(os.scandir(dialogs_dir), desc="Ubuntu dialogs subfolders",
                  total=len(os.listdir(dialogs_dir))) as subfolders:
            for sub in subfolders:
                if number_subdir == self.MAX_NUMBER_SUBDIR:
                    print("WARNING: Early stoping, only extracting {} directories".format(self.MAX_NUMBER_SUBDIR))
                    return

                if sub.is_dir():
                    number_subdir += 1
                    for f in os.scandir(sub.path):
                        if f.name.endswith(".tsv"):
                            self.conversations.append({"lines": self.loadLines(f.path)})
vec2bin.py 文件源码 项目:hadan-gcloud 作者: youkpan 项目源码 文件源码 阅读 18 收藏 0 点赞 0 评论 0
def vec2bin(input_path, output_path):
    """Convert a text embedding file to the binary layout (word2vec-style).

    The header line ("<vocab_size> <vector_size>") is copied unchanged. For each
    of the ``vocab_size`` entries, the bytes of the word up to and including the
    separating space are copied verbatim. The rest of that line is parsed as
    space-separated float32 values and written as raw bytes in native byte order.

    Raises:
        ValueError: if the input ends before ``vocab_size`` words were read.
    """
    with open(input_path, "rb") as input_fd, open(output_path, "wb") as output_fd:
        header = input_fd.readline()
        output_fd.write(header)

        vocab_size, vector_size = map(int, header.split())

        for _ in tqdm(range(vocab_size)):
            word = []
            while True:
                ch = input_fd.read(1)
                if not ch:
                    # EOF: read() keeps returning b'', so without this check a
                    # truncated file would make this loop spin forever.
                    raise ValueError(
                        "unexpected end of file in {!r}: expected {} words".format(
                            input_path, vocab_size))
                output_fd.write(ch)
                if ch == b' ':
                    # Decoded only as a UTF-8 sanity check; the value is unused.
                    word = b''.join(word).decode('utf-8')
                    break
                if ch != b'\n':
                    word.append(ch)
            vector = np.fromstring(input_fd.readline(), sep=' ', dtype='float32')
            # ndarray.tostring() was removed in NumPy 2.0; tobytes() yields the
            # same bytes and exists since NumPy 1.9.
            output_fd.write(vector.tobytes())
03-evaluate.py 文件源码 项目:crema 作者: bmcfee 项目源码 文件源码 阅读 20 收藏 0 点赞 0 评论 0
def evaluate(input_path, n_jobs):
    """Score every test-set track and save/print the results.

    Pairs from ``crema.utils.get_ann_audio(input_path)`` are kept when both files'
    base names are in the ``id`` column of ``index_test.json``. Each pair is
    scored with ``track_eval`` in parallel (``n_jobs`` workers). The results
    (presumably ``(track_id, scores)`` pairs, since they are passed to ``dict``)
    are summarized on stdout and written to ``OUTPUT_PATH/test_scores.json``.
    """
    aud, ann = zip(*crema.utils.get_ann_audio(input_path))

    test_idx = set(pd.read_json('index_test.json')['id'])

    # Drop anything not in the test set. The pairs are filtered together:
    # filtering the two lists independently would shift one against the other
    # (pairing the wrong files) whenever a pair disagreed on membership.
    pairs = [(ann_i, aud_i) for ann_i, aud_i in zip(ann, aud)
             if crema.utils.base(ann_i) in test_idx
             and crema.utils.base(aud_i) in test_idx]

    stream = tqdm(pairs, desc='Evaluating test set', total=len(pairs))

    results = Parallel(n_jobs=n_jobs)(delayed(track_eval)(ann_i, aud_i)
                                      for ann_i, aud_i in stream)
    df = pd.DataFrame.from_dict(dict(results), orient='index')

    print('Results')
    print('-------')
    print(df.describe())

    df.to_json(os.path.join(OUTPUT_PATH, 'test_scores.json'))
trainer.py 文件源码 项目:treelstm.pytorch 作者: dasguptar 项目源码 文件源码 阅读 32 收藏 0 点赞 0 评论 0
def train(self, dataset):
        """Run one training epoch over ``dataset`` in a random order.

        Gradients are accumulated over ``self.args.batchsize`` examples and then
        applied with a single optimizer step. A final, shorter group at the end
        of the epoch is applied too. Increments ``self.epoch``.

        Returns:
            The mean per-example loss over the epoch.
        """
        self.model.train()
        self.optimizer.zero_grad()
        total_loss = 0.0
        n_examples = len(dataset)
        indices = torch.randperm(n_examples)
        for idx in tqdm(range(n_examples),desc='Training epoch ' + str(self.epoch + 1) + ''):
            ltree, lsent, rtree, rsent, label = dataset[indices[idx]]
            linput, rinput = Var(lsent), Var(rsent)
            target = Var(map_label_to_target(label, dataset.num_classes))
            if self.args.cuda:
                linput, rinput = linput.cuda(), rinput.cuda()
                target = target.cuda()
            output = self.model(ltree, linput, rtree, rinput)
            loss = self.criterion(output, target)
            # NOTE(review): loss.data[0] fails for 0-dim losses (PyTorch >= 0.4);
            # use loss.item() there -- confirm the targeted PyTorch version.
            total_loss += loss.data[0]
            loss.backward()
            # Step after every `batchsize` examples and once more at the end.
            # The old `idx % batchsize == 0 and idx > 0` test made the first
            # group one example larger and silently dropped the trailing group's
            # gradients (they were cleared by the next epoch's zero_grad()).
            if (idx + 1) % self.args.batchsize == 0 or idx + 1 == n_examples:
                self.optimizer.step()
                self.optimizer.zero_grad()
        self.epoch += 1
        return total_loss / n_examples

    # helper function for testing
trainer.py 文件源码 项目:treelstm.pytorch 作者: dasguptar 项目源码 文件源码 阅读 29 收藏 0 点赞 0 评论 0
def test(self, dataset):
        """Evaluate the model on ``dataset`` without updating it.

        The prediction for each example is ``sum_k k * exp(output_k)`` over
        ``k = 1..num_classes``, treating the model output as log-probabilities.

        Returns:
            ``(mean loss, predictions)`` with one prediction per example.
        """
        self.model.eval()
        n_examples = len(dataset)
        running_loss = 0
        predictions = torch.zeros(n_examples)
        class_values = torch.arange(1, dataset.num_classes + 1)
        for i in tqdm(range(n_examples),desc='Testing epoch  ' + str(self.epoch) + ''):
            ltree, lsent, rtree, rsent, label = dataset[i]
            linput = Var(lsent, volatile=True)
            rinput = Var(rsent, volatile=True)
            target = Var(map_label_to_target(label, dataset.num_classes), volatile=True)
            if self.args.cuda:
                linput = linput.cuda()
                rinput = rinput.cuda()
                target = target.cuda()
            output = self.model(ltree, linput, rtree, rinput)
            running_loss += self.criterion(output, target).data[0]
            log_probs = output.data.squeeze().cpu()
            predictions[i] = torch.dot(class_values, torch.exp(log_probs))
        return running_loss / n_examples, predictions
CASIA.py 文件源码 项目:PyCasia 作者: lucaskjaero 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def load_dataset(self, dataset, verbose=True):
        """
        Load a directory of gnt files. Yields the image and label in tuples.

        This is a generator: the availability check (``self.get_dataset``) runs
        when iteration starts, not when the method is called.

        :param dataset: The directory to load.
        :param verbose: Print which dataset is being loaded.
        :return:  Yields (Pillow.Image.Image, label) pairs.
        :raises AssertionError: if ``self.get_dataset(dataset)`` is not True.
        """
        # Explicit check instead of `assert`: under `python -O` the whole assert
        # statement -- including the get_dataset() call inside it -- is skipped.
        # AssertionError is kept so existing callers see the same exception type.
        if self.get_dataset(dataset) is not True:
            raise AssertionError("Datasets aren't properly downloaded, "
                                 "rerun to try again or download datasets manually.")

        if verbose:
            print("Loading %s" % dataset)

        dataset_path = self.base_dataset_path + dataset
        for path in tqdm(glob.glob(dataset_path + "/*.gnt")):
            for image, label in self.load_gnt_file(path):
                yield image, label
kaggle.py 文件源码 项目:catalearn 作者: Catalearn 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def __download_competition_file(self, competition, file_name, browser):
        """Download one Kaggle competition file into the current directory.

        Args:
            competition: competition slug used in the download URL.
            file_name: name of the file within the competition.
            browser: object whose ``get(url, stream=True)`` returns a
                requests-style response (assumed; only ``headers``,
                ``status_code``, ``iter_content`` and ``close`` are used).

        Returns:
            True on success, False if the server did not answer 200.
        """
        url = 'https://www.kaggle.com/c/%s/download/%s' % (competition, file_name)
        res = browser.get(url, stream=True)

        # Close the streamed response on every path so the connection is released.
        try:
            total_size = int(res.headers.get('content-length', 0))

            if res.status_code != 200:
                print('error downloading %s' % file_name)
                return False

            file_name = os.path.basename(url)
            chunk_size = 32 * 1024

            with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name) as pbar, \
                    open(file_name, 'wb') as file_handle:
                for data in res.iter_content(chunk_size):
                    file_handle.write(data)
                    # Count the bytes actually received; the last chunk is
                    # usually shorter than chunk_size.
                    pbar.update(len(data))
        finally:
            res.close()

        return True
make_video.py 文件源码 项目:traffic_detection_yolo2 作者: wAuner 项目源码 文件源码 阅读 21 收藏 0 点赞 0 评论 0
def frames2video(name, path):
    """
    Merges images in path into a video

    Frames are read in sorted file-name order and written at 25 fps to
    ``<name>_video.mp4``.

    :param name: output file prefix
    :param path: path with prediction images
    :return: None
    """
    fnames = sorted(os.listdir(path))

    videowriter = imageio.get_writer(name + '_video.mp4', fps=25)
    # Close the writer even if reading or appending a frame fails.
    try:
        for fname in tqdm.tqdm(fnames):
            videowriter.append_data(plt.imread(os.path.join(path, fname)))
    finally:
        videowriter.close()


问题


面经


文章

微信
公众号

扫码关注公众号