attention_decoder.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:sequencing 作者: SwordYork 项目源码 文件源码
def mask_finished(finished, now_, prev_):
        mask = tf.expand_dims(tf.to_float(finished), 1)

        if isinstance(prev_, tuple):
            # tuple states
            next_ = []
            for ns, s in zip(now_, prev_):
                # fucking LSTMStateTuple
                if isinstance(ns, LSTMStateTuple):
                    next_.append(
                        LSTMStateTuple(c=(1. - mask) * ns.c + mask * s.c,
                                       h=(1. - mask) * ns.h + mask * s.h))
                else:
                    next_.append((1. - mask) * ns + mask * s)
            next_ = tuple(next_)
        else:
            next_ = (1. - mask) * now_ + mask * prev_

        return next_
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号