extensions.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:dl4mt-multi 作者: nyu-dl 项目源码 文件源码
def do(self, callback_name, *args):
        """Compute, log and record the mean log-probability of every stream.

        For each ``(cg_name, stream)`` pair in ``self.streams``, one full
        epoch of the stream is fed through ``self.f_log_probs[cg_name]`` and
        all returned log-probabilities are averaged.  Each mean is logged and
        added to ``self.main_loop.log`` under the key ``'logprob_<cg_name>'``.

        Args:
            callback_name: Name of the triggering callback (unused here).
            *args: Extra callback arguments (unused here).
        """
        logger.info(" Computing log-probs...")
        start = time.time()
        mean_probs = {}
        for cg_name, stream in self.streams.items():
            src_id, trg_id = p_(cg_name)

            # Multi-source/multi-target model: locate this stream's encoder
            # and decoder by their position in enc_ids / dec_ids.
            src_idx = self.enc_ids.index(src_id)
            trg_idx = self.dec_ids.index(trg_id)

            cg_probs = []
            for i, batch in enumerate(stream.get_epoch_iterator()):
                batch_size = batch[0].shape[0]
                # One-hot selector rows: column src_idx / trg_idx set to 1
                # for every example in the batch.
                src_sel = numpy.zeros((batch_size, self.num_encs),
                                      dtype=theano.config.floatX)
                src_sel[:, src_idx] = 1.
                trg_sel = numpy.zeros((batch_size, self.num_decs),
                                      dtype=theano.config.floatX)
                trg_sel[:, trg_idx] = 1.

                # NOTE(review): the first four batch sources are transposed,
                # presumably from (batch, time) to the (time, batch) layout
                # the compiled function expects -- confirm against f_log_probs.
                inps = [batch[0].T, batch[1].T, batch[2].T, batch[3].T,
                        src_sel, trg_sel]

                # Flatten so that batches of different sizes (e.g. a short
                # final batch) can be pooled into one flat list.
                pprobs = numpy.asarray(
                    self.f_log_probs[cg_name](*inps)).ravel()

                # Check only the new batch instead of re-averaging the whole
                # history, and report rather than dropping into a debugger,
                # which would hang unattended runs.
                if numpy.isnan(pprobs).any():
                    logger.warning("NaN log-prob for CG [%s] in batch %d",
                                   cg_name, i)
                cg_probs.extend(pprobs.tolist())

            mean_probs[cg_name] = numpy.mean(cg_probs)
            logger.info('logprob for CG [%s]: %s',
                        cg_name, mean_probs[cg_name])

        logger.info("took %s seconds.", time.time() - start)
        records = [('logprob_' + k, v) for k, v in mean_probs.items()]
        self.add_records(self.main_loop.log, records)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号