LSTMEncDecAttn.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:mlpnlp-nmt 作者: mlpnlp 项目源码 文件源码
def __call__(self, hx, cx, xs, flag_train, args):
    """Run the stacked LSTM over an input sequence via ``n_step_lstm``.

    Args:
        hx: Initial hidden states, or None to create them with
            ``self.init_hx(xs)``. Presumably shaped
            (n_layers, minibatch, units) -- TODO confirm.
        cx: Initial cell states, or None; also created with
            ``self.init_hx(xs)`` (the same helper used for hx).
        xs: Input sequence handed straight to ``chaFunc.n_step_lstm``;
            presumably (seq_len, minibatch, in_dim) -- TODO confirm.
        flag_train: Training flag. It is only forwarded (as ``train=``) on
            Chainer < 2. On Chainer >= 2 it is ignored and train/test mode
            is taken from ``chainer.config.train`` instead.
        args: Options object; ``args.chainer_version_check[0]`` is read as
            the Chainer major version (assumed to be an int).

    Returns:
        Tuple ``(hy, cy, hlist)``: the final hidden and cell states returned
        by ``n_step_lstm``, and ``hlist``, the per-timestep outputs ``ys``
        stacked into a single Variable along a new leading (time) axis.
    """
    if hx is None:
        hx = self.init_hx(xs)
    if cx is None:
        # The cell state is initialised with the same helper as hx.
        cx = self.init_hx(xs)

    # Chainer 2 removed the ``train``/``use_cudnn`` keyword arguments of
    # n_step_lstm (they became chainer.config settings), and passing them
    # raises an error. Every major version >= 2 must therefore use the
    # keyword-free call. The previous ``== 2`` test sent Chainer 3+ down
    # the legacy branch and crashed.
    # NOTE(review): the original note about where n_step_lstm applies
    # dropout was lost to encoding damage; check the Chainer docs if the
    # first layer's input dropout matters.
    if args.chainer_version_check[0] >= 2:
        hy, cy, ys = chaFunc.n_step_lstm(
            self.n_layers, self.dropout_rate, hx, cx, self.ws, self.bs, xs)
    else:
        hy, cy, ys = chaFunc.n_step_lstm(
            self.n_layers, self.dropout_rate, hx, cx, self.ws, self.bs, xs,
            train=flag_train, use_cudnn=self.use_cudnn)

    # ys is a sequence with one last-layer output per timestep. Stack it
    # into one Variable that is easier to handle downstream; presumably
    # (seq_len, minibatch, units).
    hlist = chaFunc.stack(ys)
    return hy, cy, hlist


# LSTM ... (the rest of this comment was lost: its non-ASCII characters were replaced by '?' in this copy)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号