def main():
    """Load a pretrained model, apply vertical-horizontal conv decomposition, and save it."""
    base_model = utils.load_model(args)
    decomposed = conv_vh_decomposition(base_model, args)
    decomposed.save(args.save_model)
# Example usages of load_model() collected from several scripts.
def main():
    """Load a pretrained model, apply fully-connected layer decomposition, and save it."""
    base_model = utils.load_model(args)
    decomposed = fc_decomposition(base_model, args)
    decomposed.save(args.save_model)
def main():
    """Run an interactive negotiation chat between a human and an AI agent."""
    parser = argparse.ArgumentParser(description='chat utility')
    parser.add_argument('--model_file', type=str, help='model file')
    parser.add_argument('--domain', type=str, default='object_division', help='domain for the dialogue')
    parser.add_argument('--context_file', type=str, default='', help='context file')
    parser.add_argument('--temperature', type=float, default=1.0, help='temperature')
    parser.add_argument('--num_types', type=int, default=3, help='number of object types')
    parser.add_argument('--num_objects', type=int, default=6, help='total number of objects')
    parser.add_argument('--max_score', type=int, default=10, help='max score per object')
    parser.add_argument('--score_threshold', type=int, default=6,
        help='successful dialog should have more than score_threshold in score')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--smart_ai', action='store_true', default=False, help='make AI smart again')
    parser.add_argument('--ai_starts', action='store_true', default=False, help='allow AI to start the dialog')
    parser.add_argument('--ref_text', type=str, help='file with the reference text')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    human_agent = HumanAgent(domain.get_domain(args.domain))
    # Rollout-based agent when --smart_ai is set, plain LSTM agent otherwise.
    agent_cls = LstmRolloutAgent if args.smart_ai else LstmAgent
    ai_agent = agent_cls(utils.load_model(args.model_file), args)
    # The AI opens the dialogue only when --ai_starts is given.
    participants = [ai_agent, human_agent] if args.ai_starts else [human_agent, ai_agent]

    dialog = Dialog(participants, args)
    logger = DialogLogger(verbose=True)
    # Either take manually produced contexts, or rely on the ones from the dataset.
    if args.context_file == '':
        ctx_gen = ManualContextGenerator(args.num_types, args.num_objects, args.max_score)
    else:
        ctx_gen = ContextGenerator(args.context_file)

    Chat(dialog, ctx_gen, logger).run()
def main():
    """Run model-vs-model selfplay negotiations over contexts read from a file."""
    parser = argparse.ArgumentParser(description='selfplaying script')
    parser.add_argument('--alice_model_file', type=str, help='Alice model file')
    parser.add_argument('--bob_model_file', type=str, help='Bob model file')
    parser.add_argument('--context_file', type=str, help='context file')
    parser.add_argument('--temperature', type=float, default=1.0, help='temperature')
    parser.add_argument('--verbose', action='store_true', default=False, help='print out converations')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--score_threshold', type=int, default=6,
        help='successful dialog should have more than score_threshold in score')
    parser.add_argument('--max_turns', type=int, default=20, help='maximum number of turns in a dialog')
    parser.add_argument('--log_file', type=str, default='', help='log successful dialogs to file for training')
    parser.add_argument('--smart_alice', action='store_true', default=False, help='make Alice smart again')
    parser.add_argument('--fast_rollout', action='store_true', default=False, help='to use faster rollouts')
    parser.add_argument('--rollout_bsz', type=int, default=100, help='rollout batch size')
    parser.add_argument('--rollout_count_threshold', type=int, default=3, help='rollout count threshold')
    parser.add_argument('--smart_bob', action='store_true', default=False, help='make Bob smart again')
    parser.add_argument('--ref_text', type=str, help='file with the reference text')
    parser.add_argument('--domain', type=str, default='object_division', help='domain for the dialogue')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    def build_agent(model_file, smart, name):
        # Load a model and wrap it in the agent type matching its smart/rollout flags.
        agent_model = utils.load_model(model_file)
        agent_cls = get_agent_type(agent_model, smart, args.fast_rollout)
        return agent_cls(agent_model, args, name=name)

    alice = build_agent(args.alice_model_file, args.smart_alice, 'Alice')
    bob = build_agent(args.bob_model_file, args.smart_bob, 'Bob')

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    SelfPlay(dialog, ctx_gen, args, logger).run()
def main():
    """Rank-evaluation script.

    Replays ground-truth dialogues against an LSTM agent and, for every
    utterance the agent was supposed to produce (``you`` turns), computes the
    rank of the reference sentence among candidate sentences under the model's
    scoring function. Dumps a running average rank per dialogue and a final
    average to the logger.
    """
    parser = argparse.ArgumentParser(description='Negotiator')
    parser.add_argument('--dataset', type=str, default='./data/negotiate/val.txt',
        help='location of the dataset')
    parser.add_argument('--model_file', type=str,
        help='model file')
    parser.add_argument('--smart_ai', action='store_true', default=False,
        help='to use rollouts')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    parser.add_argument('--log_file', type=str, default='',
        help='log file')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    model = utils.load_model(args.model_file)
    ai = LstmAgent(model, args)
    logger = DialogLogger(verbose=True, log_file=args.log_file)
    domain = get_domain(args.domain)
    # Rollout-based scoring when --smart_ai is given, plain likelihood otherwise.
    score_func = rollout if args.smart_ai else likelihood

    dataset, sents = read_dataset(args.dataset)
    # ranks: cumulative rank sum; n: number of scored utterances; k: dialogue counter.
    ranks, n, k = 0, 0, 0
    for ctx, dialog in dataset:
        start_time = time.time()
        # start new conversation
        ai.feed_context(ctx)
        for sent, you in dialog:
            if you:
                # if it is your turn to say, take the target word and compute its rank
                rank = compute_rank(sent, sents, ai, domain, args.temperature, score_func)
                # compute lang_h for the groundtruth sentence
                enc = ai._encode(sent, ai.model.word_dict)
                _, ai.lang_h, lang_hs = ai.model.score_sent(enc, ai.lang_h, ai.ctx_h, args.temperature)
                # save hidden states and the utterance
                ai.lang_hs.append(lang_hs)
                ai.words.append(ai.model.word2var('YOU:'))
                ai.words.append(Variable(enc))
                ranks += rank
                n += 1
            else:
                # Partner turn: just feed the sentence to the agent.
                ai.read(sent)
        k += 1
        time_elapsed = time.time() - start_time
        # NOTE(review): divides by n — raises ZeroDivisionError if the first
        # dialogue contains no YOU turns; confirm the dataset guarantees one.
        logger.dump('dialogue %d | avg rank %.3f | raw %d/%d | time %.3f' % (k, 1. * ranks / n, ranks, n, time_elapsed))
    logger.dump('final avg rank %.3f' % (1. * ranks / n))
def main():
    """Evaluate a pretrained model on the test split, reporting LM and selection perplexity."""
    parser = argparse.ArgumentParser(description='testing script')
    parser.add_argument('--data', type=str, default='data/negotiate', help='location of the data corpus')
    parser.add_argument('--unk_threshold', type=int, default=20,
        help='minimum word frequency to be in dictionary')
    parser.add_argument('--model_file', type=str, help='pretrained model file')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--hierarchical', action='store_true', default=False, help='use hierarchical model')
    parser.add_argument('--bsz', type=int, default=16, help='batch size')
    parser.add_argument('--cuda', action='store_true', default=False, help='use CUDA')
    args = parser.parse_args()

    device_id = utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    corpus = data.WordCorpus(args.data, freq_cutoff=args.unk_threshold, verbose=True)
    model = utils.load_model(args.model_file)
    # Language-model criterion over words, selection criterion over items;
    # disconnect/disagree tokens are excluded from the selection loss.
    crit = Criterion(model.word_dict, device_id=device_id)
    sel_crit = Criterion(model.item_dict, device_id=device_id,
        bad_toks=['<disconnect>', '<disagree>'])
    testset, testset_stats = corpus.test_dataset(args.bsz, device_id=device_id)

    total_lm_loss, total_sel_loss = 0, 0
    vocab_size = len(corpus.word_dict)
    for batch in testset:
        # Forward pass yields LM output/target and selection output/target.
        out, hid, tgt, sel_out, sel_tgt = Engine.forward(model, batch, volatile=False)
        # LM loss is weighted by the number of target tokens in the batch.
        total_lm_loss += tgt.size(0) * crit(out.view(-1, vocab_size), tgt).data[0]
        total_sel_loss += sel_crit(sel_out, sel_tgt).data[0]

    # Normalize: LM loss per non-padding token, selection loss per batch.
    test_loss = total_lm_loss / testset_stats['nonpadn']
    test_select_loss = total_sel_loss / len(testset)
    print('testloss %.3f | testppl %.3f' % (test_loss, np.exp(test_loss)))
    print('testselectloss %.3f | testselectppl %.3f' % (test_select_loss, np.exp(test_select_loss)))
def _train(net, training_data, validation_data, model_name, learning_rate,
           max_epochs, min_improvement):
    # Train `net` with a halving learning-rate schedule (Python 2 code).
    #
    # Each epoch trains on `training_data` and measures perplexity on
    # `validation_data`. While validation PPL keeps improving by at least the
    # `min_improvement` log-factor, the model is saved as the new best. Once
    # improvement stalls, the learning rate is halved each epoch; a second
    # stall (or a regression) ends training, keeping the best saved model.
    min_learning_rate = 1e-6
    best_validation_ppl = np.inf
    # divide == True once we have entered the learning-rate-halving phase.
    divide = False
    for epoch in range(1, max_epochs + 1):
        epoch_start = time()
        print "\n======= EPOCH %s =======" % epoch
        print "\tLearning rate is %s" % learning_rate
        train_ppl = _process_corpus(net, training_data, mode='train',
                                    learning_rate=learning_rate)
        print "\tTrain PPL is %.3f" % train_ppl
        validation_ppl = _process_corpus(net, validation_data, mode='test')
        print "\tValidation PPL is %.3f" % validation_ppl
        print "\tTime taken: %ds" % (time() - epoch_start)
        # Improvement test in log space: true when the relative gain over the
        # best PPL so far is below the min_improvement threshold.
        if np.log(validation_ppl) * min_improvement > np.log(best_validation_ppl):
            if not divide:
                # First stall: switch to the LR-halving phase, and roll back to
                # the best checkpoint if this epoch actually got worse.
                divide = True
                print "\tStarting to reduce the learning rate..."
                if validation_ppl > best_validation_ppl:
                    print "\tLoading best model."
                    net = utils.load_model("../out/" + model_name)
            else:
                # Second stall while already halving: save (if improved) and stop.
                if validation_ppl < best_validation_ppl:
                    print "\tSaving model."
                    net.save("../out/" + model_name, final = True)
                break
        else:
            # Sufficient improvement: record the new best and checkpoint it.
            print "\tNew best model! Saving..."
            best_validation_ppl = validation_ppl
            # Mark the save as final when the next LR halving would go below
            # the minimum, or this is the last epoch.
            final = learning_rate / 2. < min_learning_rate or epoch == max_epochs
            net.save("../out/" + model_name, final)
        if divide:
            learning_rate /= 2.
            if learning_rate < min_learning_rate:
                break
    print "-"*30
    print "Finished training."
    print "Best validation PPL is %.3f\n\n" % best_validation_ppl
def train():
    # Train the RCNN similarity model over NEPOCH epochs (Python 2 code).
    #
    # Resumes from a checkpoint at `model_dir` when present, optionally skips
    # to `goto_line` in the training iterator, logs running cost/accuracy every
    # `disp_freq` batches and dumps parameters every `dump_freq` batches.
    # Aborts (returns -1) when a NaN/Inf cost is detected.
    log.info('loading dataset...')
    train_data=TextIterator(train_file,n_batch=batch_size,maxlen=maxlen)
    valid_data = TextIterator(valid_file, n_batch=batch_size, maxlen=maxlen)
    test_data = TextIterator(test_file, n_batch=batch_size, maxlen=maxlen,mode=2)
    log.info('building models....')
    model=RCNNModel(n_input=n_input,n_vocab=VOCABULARY_SIZE,n_hidden=n_hidden,cell='gru', optimizer=optimizer,dropout=dropout,sim=sim,maxlen=maxlen,batch_size=batch_size)
    start=time.time()
    # Resume from an existing checkpoint if one is present on disk.
    if os.path.isfile(model_dir):
        print 'loading checkpoint parameters....',model_dir
        model=load_model(model_dir,model)
    # Optionally fast-forward the iterator (e.g. when resuming mid-epoch).
    if goto_line!=0:
        train_data.goto_line(goto_line)
        print 'goto line:',goto_line
    log.info('training start...')
    for epoch in xrange(NEPOCH):
        costs=0
        idx=0
        error_rate_list=[]
        try:
            for (x,xmask),(y,ymask),label in train_data:
                idx+=1
                # Skip ragged final batches that don't match the model's batch size.
                if x.shape[-1]!=batch_size:
                    continue
                cost,error_rate=model.train(x,xmask,y,ymask,label,lr)
                #print cost,error_rate
                #projected_output,cost= model.test(x, xmask, y, ymask,label)
                #print "projected_output shape:", projected_output.shape
                ##print "cnn_output shape:",cnn_output.shape
                #print "cost:",cost
                costs+=cost
                error_rate_list.append(error_rate)
                # Bail out entirely on numerical blow-up.
                if np.isnan(cost) or np.isinf(cost):
                    print 'Nan Or Inf detected!'
                    print x.shape,y.shape
                    print cost,error_rate
                    return -1
                if idx % disp_freq==0:
                    log.info('epoch: %d, idx: %d cost: %.3f, Accuracy: %.3f '%(epoch,idx,costs/idx, np.mean(list(itertools.chain.from_iterable(error_rate_list)))))
                if idx%dump_freq==0:
                    save_model('./model/parameters_%.2f.pkl'%(time.time()-start),model)
        except Exception:
            # NOTE(review): broad best-effort catch — dumps the offending batch
            # and carries on rather than crashing the run; errors are not re-raised.
            print np.max(x),np.max(y)
            print x.shape,y.shape
        evaluate(train_data,valid_data, test_data,model)
    log.info("Finished. Time = " +str(time.time()-start))
def test():
    # Evaluate the RCNN model epoch-by-epoch (Python 2 code).
    #
    # Despite the name, this iterates per-epoch train/valid splits
    # (`<train_file>.train.<epoch>` / `<train_file>.valid.<epoch>`), runs
    # model.predict over each train split, then evaluates on the valid split.
    # Returns -1 when a NaN/Inf cost is detected.
    log.info('loading dataset...')
    log.info('building models....')
    model=RCNNModel(n_input=n_input,n_vocab=VOCABULARY_SIZE,n_hidden=n_hidden,cell='gru',optimizer=optimizer,dropout=dropout,sim=sim,maxlen=maxlen,batch_size=batch_size)
    log.info('training start....')
    start=time.time()
    # Load checkpoint parameters when available.
    if os.path.isfile(model_dir):
        print 'loading checkpoint parameters....',model_dir
        model=load_model(model_dir,model)
    for epoch in xrange(NEPOCH):
        costs=[]
        idx=0
        acc_list=[]
        # Per-epoch data splits, selected by filename suffix.
        train_data = TextIterator(train_file+".train."+str(epoch), n_batch=batch_size, maxlen=maxlen)
        valid_data = TextIterator(train_file+".valid."+str(epoch), n_batch=batch_size, maxlen=maxlen)
        for (x,xmask),(y,ymask),label in train_data:
            idx+=1
            # Skip ragged final batches that don't match the model's batch size.
            if x.shape[-1]!=batch_size:
                continue
            #print x.shape
            cost,acc=model.predict(x,xmask,y,ymask,label)
            #print cost
            #projected_output,cost= model.test(x, xmask, y, ymask,label)
            #print "projected_output shape:", projected_output.shape
            ##print "cnn_output shape:",cnn_output.shape
            #print "cost:",cost
            costs.append(cost)
            acc_list.append(acc)
            # Abort on numerical blow-up, dumping the offending batch.
            if np.isnan(np.mean(cost)) or np.isinf(np.mean(cost)):
                print 'Nan Or Inf detected!'
                print "x:",x
                print x.shape
                print 'y:',y
                print y.shape
                return -1
        #log.info('dumping parameters....')
        #save_model('./model/parameters_%.2f.pkl'%(time.time()-start),model)
        log.info('epoch: %d, cost: %.3f, Accuracy: %.3f ' % (
            epoch,np.mean(list(itertools.chain.from_iterable(costs))), np.mean(list(itertools.chain.from_iterable(acc_list)))))
        loss, acc = evaluate(valid_data, model)
        log.info('validation cost: %.3f, Accuracy: %.3f' % (loss,acc))
    log.info("Finished. Time = " +str(time.time()-start))