train.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:monogreedy 作者: jinjunqi 项目源码 文件源码
def eva_a_phi(phi):
    """Train the region-attention model for one size setting and evaluate it.

    Args:
        phi: sequence of four sizes, unpacked as ``(na, nnh, nh, nw)``:
            ``na``  - size of the region features after dimensionality
                      reduction,
            ``nnh`` - hidden size of the attention MLP,
            ``nh``  - hidden size of the LSTM,
            ``nw``  - size of the word embedding vector.

    Returns:
        Tuple ``(scores, running_time)`` loaded from ``scores.npy`` and
        ``running_time.npy`` in the current working directory (presumably
        written by ``valid_time.py`` -- confirm against that script).

    Raises:
        RuntimeError: if ``valid_time.py`` exits with a non-zero status.
    """
    na, nnh, nh, nw = phi

    # choose a dataset to train (mscoco, flickr8k, flickr30k)
    dataset = 'mscoco'
    data_dir = osp.join(DATA_ROOT, dataset)

    from model.ra import Model
    # settings
    mb = 64  # mini-batch size
    lr = 0.0002  # learning rate
    name = 'ra'  # model name, just setting it to 'ra' is ok. 'ra'='region attention'
    vocab_freq = 'freq5'  # use the vocabulary that filtered out words whose frequencies are less than 5

    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print('... loading data {}'.format(dataset))
    train_set = Reader(batch_size=mb, data_split='train', vocab_freq=vocab_freq, stage='train',
                       data_dir=data_dir, feature_file='features_30res.h5', topic_switch='off') # change 0, 1000, 82783
    valid_set = Reader(batch_size=1, data_split='val', vocab_freq=vocab_freq, stage='val',
                       data_dir=data_dir, feature_file='features_30res.h5',
                       caption_switch='off', topic_switch='off') # change 0, 10, 5000

    npatch, nimg = train_set.features.shape[1:]
    nout = len(train_set.vocab)
    # The output directory name encodes the dataset and every size setting.
    save_dir = '{}-nnh{}-nh{}-nw{}-na{}-mb{}-V{}'.\
        format(dataset.lower(), nnh, nh, nw, na, mb, nout)
    save_dir = osp.join(SAVE_ROOT, save_dir)

    model_file, m = find_last_snapshot(save_dir, resume_training=False)
    # Best-effort copy of the model source into save_dir; the exit status is
    # intentionally not checked, so a failed copy does not abort training.
    os.system('cp model/ra.py {}/'.format(save_dir))
    logger = Logger(save_dir)
    logger.info('... building')
    model = Model(name=name, nimg=nimg, nnh=nnh, nh=nh, na=na, nw=nw, nout=nout, npatch=npatch, model_file=model_file)

    # start training
    bs = BeamSearch([model], beam_size=1, num_cadidates=100, max_length=20)
    best = train(model, bs, train_set, valid_set, save_dir, lr,
                 display=100, starting=m, endding=20, validation=2000, life=10, logger=logger) # change dis1,100; va 2,2000; life 0,10;
    average_models(best=best, L=6, model_dir=save_dir, model_name=name+'.h5') # L 1, 6

    # evaluation
    # np.save appends '.npy', so these write data_dir.npy / save_dir.npy in the
    # working directory (presumably read back by valid_time.py -- confirm).
    np.save('data_dir', data_dir)
    np.save('save_dir', save_dir)

    # A failed validation run must not be masked: without this check, result
    # files left over from an earlier run would be loaded and returned below.
    status = os.system('python valid_time.py')
    if status != 0:
        raise RuntimeError(
            'valid_time.py failed (os.system status {}); refusing to load '
            'possibly stale scores.npy / running_time.npy'.format(status))

    scores = np.load('scores.npy')
    running_time = np.load('running_time.npy')
    # %s applies str() to each value, matching what the print statement did.
    print('cider: %s B1-4,C: %s running time: %s' % (scores[-1], scores, running_time))

    return scores, running_time
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号