def eva_a_phi(phi):
    """Train a region-attention captioning model for one size setting and score it.

    Args:
        phi: sequence of four sizes, unpacked as ``(na, nnh, nh, nw)``:
            na  -- size of the region features after dimensionality reduction
            nnh -- hidden size of the attention MLP
            nh  -- hidden size of the LSTM
            nw  -- size of the word embedding vector

    Returns:
        ``(scores, running_time)`` as loaded from ``scores.npy`` and
        ``running_time.npy`` in the current working directory after
        ``valid_time.py`` has been run.

    Raises:
        RuntimeError: if ``valid_time.py`` exits with a non-zero status.
    """
    import subprocess

    from model.ra import Model

    na, nnh, nh, nw = phi

    # choose a dataset to train (mscoco, flickr8k, flickr30k)
    dataset = 'mscoco'
    data_dir = osp.join(DATA_ROOT, dataset)

    # settings
    mb = 64  # mini-batch size
    lr = 0.0002  # learning rate
    name = 'ra'  # model name; 'ra' = 'region attention'
    # vocabulary that filters out words whose frequencies are less than 5
    vocab_freq = 'freq5'

    print('... loading data {}'.format(dataset))
    # Author's note on the next two readers (meaning not evident from this
    # code): train "change 0, 1000, 82783"; val "change 0, 10, 5000".
    train_set = Reader(batch_size=mb, data_split='train',
                       vocab_freq=vocab_freq, stage='train',
                       data_dir=data_dir, feature_file='features_30res.h5',
                       topic_switch='off')
    valid_set = Reader(batch_size=1, data_split='val',
                       vocab_freq=vocab_freq, stage='val',
                       data_dir=data_dir, feature_file='features_30res.h5',
                       caption_switch='off', topic_switch='off')

    npatch, nimg = train_set.features.shape[1:]
    nout = len(train_set.vocab)

    # The output directory name encodes the dataset, the four sizes from
    # ``phi``, the mini-batch size and the vocabulary size.
    save_dir = '{}-nnh{}-nh{}-nw{}-na{}-mb{}-V{}'.format(
        dataset.lower(), nnh, nh, nw, na, mb, nout)
    save_dir = osp.join(SAVE_ROOT, save_dir)
    model_file, m = find_last_snapshot(save_dir, resume_training=False)

    # Keep a copy of the model source next to the results. This stays
    # best-effort (exit status ignored, as before), but the arguments are
    # passed as a list so paths with spaces or shell metacharacters work.
    subprocess.call(['cp', 'model/ra.py', '{}/'.format(save_dir)])

    logger = Logger(save_dir)
    logger.info('... building')
    model = Model(name=name, nimg=nimg, nnh=nnh, nh=nh, na=na, nw=nw,
                  nout=nout, npatch=npatch, model_file=model_file)

    # start training
    # NOTE: ``num_cadidates`` and ``endding`` are spelled exactly as in the
    # original calls; presumably they match the callees' signatures.
    bs = BeamSearch([model], beam_size=1, num_cadidates=100, max_length=20)
    # Author's note (presumably quick-debug vs. full-run values):
    # display 1/100; validation 2/2000; life 0/10.
    best = train(model, bs, train_set, valid_set, save_dir, lr,
                 display=100, starting=m, endding=20, validation=2000,
                 life=10, logger=logger)
    # Author's note: L 1, 6
    average_models(best=best, L=6, model_dir=save_dir, model_name=name + '.h5')

    # evaluation
    # The two directories are written to ``data_dir.npy`` / ``save_dir.npy``
    # (np.save appends the extension) before the evaluation script runs;
    # presumably valid_time.py reads them back -- confirm against that script.
    np.save('data_dir', data_dir)
    np.save('save_dir', save_dir)
    status = subprocess.call(['python', 'valid_time.py'])
    if status != 0:
        # Without this check a failed evaluation would silently return
        # whatever scores.npy / running_time.npy an earlier run left behind.
        raise RuntimeError(
            'valid_time.py exited with status {}'.format(status))

    scores = np.load('scores.npy')
    running_time = np.load('running_time.npy')
    print('cider: {!s} B1-4,C: {!s} running time: {!s}'.format(
        scores[-1], scores, running_time))
    return scores, running_time
# [web-page residue] Comment list
# [web-page residue] Table of contents