LSTMEncDecAttn.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:mlpnlp-nmt 作者: mlpnlp 项目源码 文件源码
def __init__(self, n_layers,  # number of stacked LSTM layers
                 in_size,  # input dimension of the first layer
                 out_size,  # output dimension of every layer
                 dropout_rate,
                 name="",
                 use_cudnn=True):
        """Create one parameter link per layer for a stacked LSTM.

        For every layer a ``chainer.Link`` is built holding eight weight
        matrices (``w0`` .. ``w7``) and eight bias vectors (``b0`` .. ``b7``),
        optionally prefixed with ``'<name>_'``.

        Weight shapes:
            * ``w0``-``w3`` are ``(out_size, fan_in)``. ``fan_in`` is
              ``in_size`` for the first layer and ``out_size * direction``
              for deeper layers.
            * ``w4``-``w7`` are ``(out_size, out_size)``. They presumably act
              on the recurrent hidden state; TODO confirm in the forward pass.

        Initialization:
            * Weights are drawn from a normal distribution with mean 0 and
              standard deviation ``sqrt(1 / fan_in)``.
            * Biases are zero.

        Args:
            n_layers: Number of stacked layers (one Link is created per layer).
            in_size: Input dimension consumed by the first layer.
            out_size: Row count of every weight/bias, i.e. the layer output
                size.
            dropout_rate: Only stored on ``self.dropout_rate`` here.
            name: Optional prefix for parameter names. A non-empty value
                yields names such as ``'<name>_w0'``; ``""`` yields ``'w0'``.
            use_cudnn: Only stored on ``self.use_cudnn`` here.
        """
        n_params = 8  # weight/bias pairs per layer: w0..w7, b0..b7
        # Only the unidirectional case is built; a bidirectional stack would
        # use 2 here.
        direction = 1

        # Compare by value: the former `name is not ""` relied on CPython
        # string interning and emits a SyntaxWarning on Python 3.8+.
        t_name = '%s_' % (name) if name != "" else name

        weights = []
        for i in six.moves.range(n_layers):
            for _di in six.moves.range(direction):
                weight = chainer.Link()
                for j in six.moves.range(n_params):
                    if j < 4:
                        # Input-side matrices: layer 0 reads the external
                        # input; deeper layers read the previous layer's
                        # output.
                        w_in = in_size if i == 0 else out_size * direction
                    else:
                        w_in = out_size
                    w_key = '%sw%d' % (t_name, j)
                    b_key = '%sb%d' % (t_name, j)
                    weight.add_param(w_key, (out_size, w_in))
                    weight.add_param(b_key, (out_size,))
                    getattr(weight, w_key).data[...] = np.random.normal(
                        0, np.sqrt(1. / w_in), (out_size, w_in))
                    getattr(weight, b_key).data[...] = 0
                weights.append(weight)

        super(NStepLSTMpp, self).__init__(*weights)

        self.n_layers = n_layers
        self.dropout_rate = dropout_rate
        self.use_cudnn = use_cudnn
        self.out_size = out_size
        self.direction = direction
        # Per-layer lists [w0..w7] and [b0..b7]. Iterating `self` is assumed
        # to yield the child links passed to the parent constructor above
        # (presumably chainer.ChainList; TODO confirm against the class).
        self.ws = [[getattr(w, '%sw%d' % (t_name, j))
                    for j in six.moves.range(n_params)] for w in self]
        self.bs = [[getattr(w, '%sb%d' % (t_name, j))
                    for j in six.moves.range(n_params)] for w in self]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号