decoder.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:sgnmt 作者: ucam-smt 项目源码 文件源码
def initial_states(self, batch_size, *args, **kwargs):
        """Returns the initial state depending on ``init_strategy``."""
        attended = kwargs['attended']
        if self.init_strategy == 'constant':
            initial_state = [tensor.repeat(self.parameters[2][None, :],
                                           batch_size,
                                           0)]
        elif self.init_strategy == 'last':
            initial_state = self.initial_transformer.apply(
                attended[0, :, -self.attended_dim:])
        elif self.init_strategy == 'average':
            initial_state = self.initial_transformer.apply(
                attended[:, :, -self.attended_dim:].mean(0))  
        else:
            logging.fatal("dec_init parameter %s invalid" % self.init_strategy)
        return initial_state
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号