def do(self, callback_name, *args):
    """Compute, print and record the mean log-probability of every graph.

    For each ``(cg_name, stream)`` pair in ``self.streams``, one epoch of
    ``stream`` is fed through ``self.f_log_probs[cg_name]``. All returned
    values are pooled, and their mean is printed. One record
    ``('logprob_' + cg_name, mean)`` per graph is then added to
    ``self.main_loop.log`` via ``self.add_records``.

    Parameters
    ----------
    callback_name : str
        Name of the triggering callback (not used by this method).
    *args
        Extra callback arguments (not used by this method).

    Notes
    -----
    A stream that yields no batches produces a NaN mean. numpy.mean of an
    empty sequence returns NaN with a RuntimeWarning.
    """
    means = {}
    print('')
    logger.info(" Computing log-probs...")
    start = time.time()
    float_x = theano.config.floatX
    for cg_name, stream in self.streams.items():
        logprobs = []
        src_id, trg_id = p_(cg_name)
        # Each graph pairs one encoder with one decoder. Their positions in
        # enc_ids / dec_ids select the column that is set to 1 in the
        # per-batch selector matrices below.
        src_idx = self.enc_ids.index(src_id)
        trg_idx = self.dec_ids.index(trg_id)
        for batch in stream.get_epoch_iterator():
            batch_size = batch[0].shape[0]
            src_sel = numpy.zeros((batch_size, self.num_encs), dtype=float_x)
            src_sel[:, src_idx] = 1.
            trg_sel = numpy.zeros((batch_size, self.num_decs), dtype=float_x)
            trg_sel[:, trg_idx] = 1.
            # The first four batch sources are transposed before the call.
            # Presumably (source, source_mask, target, target_mask) stored
            # batch-major -- TODO confirm against the data stream.
            inps = [batch[0].T, batch[1].T, batch[2].T, batch[3].T,
                    src_sel, trg_sel]
            pprobs = self.f_log_probs[cg_name](*inps)
            # Flatten, so batches of different sizes (e.g. a short final
            # batch) do not form a ragged nested list that numpy.mean cannot
            # average. For scalar or equal-sized results the mean is the
            # same as before.
            logprobs.extend(numpy.ravel(pprobs).tolist())
        mean_logprob = numpy.mean(logprobs)
        # Report NaN once per graph rather than halting in a debugger,
        # which would stall an unattended run.
        if numpy.isnan(mean_logprob):
            logger.warning("log-prob for CG [%s] is NaN", cg_name)
        means[cg_name] = mean_logprob
        print('logprob for CG [{}]: {}'.format(cg_name, mean_logprob))
    print("took {} seconds.".format(time.time() - start))
    records = [('logprob_' + k, v) for k, v in means.items()]
    self.add_records(self.main_loop.log, records)
评论列表
文章目录