algorithm.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:dl4mt-multi 作者: nyu-dl 项目源码 文件源码
def process_batch(self, batch, unk_id=1):
    """Run one update step on ``batch`` with its computational graph.

    The computational-graph id is obtained from the first entries of
    ``batch['src_selector']`` and ``batch['trg_selector']`` via
    ``self.get_cg_id_from_selectors``. If ``self.drop_input[cg_id]`` is
    positive, that fraction of the entries of ``batch['source']`` (at least
    one, at most all of them) is overwritten in place with ``unk_id``. After
    that, the update function(s) registered for ``cg_id`` are called.

    Parameters
    ----------
    batch : dict
        Maps input names to arrays. It must contain ``'src_selector'``,
        ``'trg_selector'`` and ``'source'``. It must also contain one entry
        for every ``v.name`` in ``self.algorithms[cg_id].inputs``.
    unk_id : int, optional
        Value written into the dropped source positions. Defaults to 1,
        the value that used to be hard-coded here.

    Side effects
    ------------
    Sets ``self._cost`` to ``('cost_' + cg_id, cost)``, where ``cost`` is
    the value returned by the cost-producing function. May modify
    ``batch['source']`` in place.
    """
    cg_id = self.get_cg_id_from_selectors(batch['src_selector'][0],
                                          batch['trg_selector'][0])

    # Apply input replacement with <UNK> if necessary.
    drop_rate = self.drop_input[cg_id]
    if drop_rate > 0.0:
        source = batch['source']
        num_els = source.size
        # An empty source has nothing to replace. Sampling from an empty
        # population would raise ValueError.
        if num_els > 0:
            # Replace at least one element, but never more than exist:
            # choice(..., replace=False) raises if asked for too many.
            num_reps = min(num_els, max(1, int(num_els * drop_rate)))
            replace_idx = numpy.random.choice(num_els, num_reps,
                                              replace=False)
            # TODO: default unk_id from the config's unk_id.
            source[numpy.unravel_index(replace_idx, source.shape)] = unk_id

    ordered_batch = [batch[v.name] for v in self.algorithms[cg_id].inputs]

    # To save memory, f_update and f_grad_shared may be combined. That is
    # signalled by f_grad_shareds[cg_id] being None, in which case
    # f_updates[cg_id] takes the learning rate followed by the inputs and
    # returns the cost.
    f_grad_shared = self.f_grad_shareds[cg_id]
    f_update = self.f_updates[cg_id]
    if f_grad_shared is None:
        cost = f_update(self.learning_rate, *ordered_batch)
    else:
        cost = f_grad_shared(*ordered_batch)
    self._cost = ('cost_' + cg_id, cost)
    if f_grad_shared is not None:
        # Separate update step, called with only the learning rate.
        # Presumably it consumes state prepared by f_grad_shared; confirm.
        f_update(self.learning_rate)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号