dern.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:der-network 作者: soskek 项目源码 文件源码
def __init__(self, vocab, args):
        """Build the DERN links and cache special-token ids from the vocabulary.

        Args:
            vocab: Mapping from token string to integer id. Its length sets
                the embedding table size, and it must provide entries for
                "@placeholder", "<eos>", "<bos>", "<boq>" and "NULL_tok".
            args: Configuration object; ``args.n_units`` is the hidden size
                used by every link and ``args.d_ratio`` is stored as the
                dropout ratio.
        """
        units = args.n_units
        vocab_size = len(vocab)

        def glorot_normal(shape):
            # Zero-mean Gaussian with std sqrt(2 / sum(shape)); for a 2-D
            # (out, in) weight this is Glorot/Xavier normal initialisation.
            std = (2.0 / sum(shape)) ** 0.5
            return np.random.normal(0, std, shape).astype(np.float32)

        def dense(in_size, out_size):
            # Linear link with an explicitly drawn (out_size, in_size) weight.
            # The weight is sampled before the link is constructed.
            weight = glorot_normal((out_size, in_size))
            return L.Linear(in_size, out_size, initialW=weight)

        # NOTE: keyword order below fixes the order in which each link's
        # random initial weights are drawn, so keep it stable.
        super(DERN, self).__init__(
            # Word embedding over the whole vocabulary
            embed=L.EmbedID(vocab_size, units),

            # Forward/backward LSTM pair for the article ...
            f_LSTM=L.LSTM(units, units),
            b_LSTM=L.LSTM(units, units),
            # ... and a second pair for the query
            Q_f_LSTM=L.LSTM(units, units),
            Q_b_LSTM=L.LSTM(units, units),

            # Projection matrices and vectors
            W_hd=dense(4 * units, units),
            W_dm=dense(units, units),
            m=dense(units, 1),
            W_hq=dense(4 * units, units),
            W_hu=dense(4 * units, units),
            W_dv=dense(units, units),
            W_dx=dense(units, units),
            W_dxQ=dense(units, units),

            b_v2=dense(1, units),
        )

        # Plain (non-link) attributes are set only after Chain.__init__ ran.
        self.args = args
        self.n_vocab = vocab_size
        self.n_units = units
        self.dropout_ratio = args.d_ratio

        # Special-token ids, looked up in a fixed order.
        self.PH_id = vocab["@placeholder"]
        self.eos_id = vocab["<eos>"]
        self.bos_id = vocab["<bos>"]
        boq_id = vocab["<boq>"]
        self.boq_id = boq_id
        # NOTE(review): self.xp is read at construction time; if the model is
        # later moved to another device these cached arrays are presumably not
        # moved with it -- confirm how callers use them.
        # One-element int32 batch holding the begin-of-query id.
        self.BOQ_tok_batch = self.xp.array([boq_id], dtype=np.int32)
        null_id = vocab["NULL_tok"]
        self.NULL_id = null_id
        # 0-d int32 array holding the NULL token id.
        self.NULL_tok = self.xp.array(null_id, dtype=np.int32)

        # Hook defined elsewhere in the class (not visible here).
        self.initialize_additionally()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号