GAReaderMatrix.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:fg-gating 作者: kimiyoung 项目源码 文件源码
def __init__(self, K, vocab_size, num_chars, W_init, regularizer, rlambda,
             nhidden, embed_dim, dropout, train_emb, subsample, char_dim,
             use_feat, feat_cnt, save_attn=False):
        """Build the symbolic model and compile its train/validate functions.

        Stores the hyper-parameters on the instance, declares the Theano
        input variables, asks ``self.build_network`` for the network outputs,
        and compiles ``self.train_fn`` (with Adam updates) and
        ``self.validate_fn``.

        Args:
            K: forwarded unchanged to ``self.build_network``.
            vocab_size: number of rows of the default embedding matrix; also
                forwarded to ``self.build_network``.
            num_chars: stored as ``self.num_chars``.
            W_init: initial embedding matrix, or ``None`` to draw a
                Glorot-normal sample of shape ``(vocab_size, embed_dim)``.
            regularizer: ``'l2'`` selects the L2 penalty; any other value
                selects L1.
            rlambda: weight of the penalty on ``W_emb - W_init``.  When it is
                positive, Glorot-normal noise is added to ``W_init`` before
                the matrix is handed to ``self.build_network``.
            nhidden, embed_dim, dropout, train_emb, subsample, char_dim,
            use_feat, feat_cnt, save_attn: stored on the instance under the
                same names (presumably read by ``build_network``; not visible
                here).  ``self.use_chars`` is true when ``char_dim`` is
                non-zero.
        """
        # --- plain configuration kept on the instance ---------------------
        self.nhidden = nhidden
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.train_emb = train_emb
        self.subsample = subsample
        self.char_dim = char_dim
        self.use_chars = self.char_dim != 0
        self.num_chars = num_chars
        self.use_feat = use_feat
        self.feat_cnt = feat_cnt
        self.save_attn = save_attn
        self.learning_rate = LEARNING_RATE

        # Penalty applied to the drift of the embeddings away from W_init.
        if regularizer == 'l2':
            penalty = lasagne.regularization.l2
        else:
            penalty = lasagne.regularization.l1

        if W_init is None:
            W_init = lasagne.init.GlorotNormal().sample(
                    (vocab_size, self.embed_dim))

        # --- symbolic inputs ----------------------------------------------
        doc_var = T.itensor3('doc')
        query_var = T.itensor3('quer')
        cand_var = T.wtensor3('cand')
        docmask_var = T.bmatrix('doc_mask')
        qmask_var = T.bmatrix('q_mask')
        candmask_var = T.bmatrix('c_mask')
        target_var = T.ivector('ans')
        feat_var = T.imatrix('feat')
        doc_toks = T.imatrix('dchars')
        qry_toks = T.imatrix('qchars')
        tok_var = T.imatrix('tok')
        tok_mask = T.bmatrix('tok_mask')
        cloze_var = T.ivector('cloze')
        match_feat_var = T.itensor3('match_feat')
        use_char_var = T.tensor3('use_char')
        use_char_q_var = T.tensor3('use_char_q')

        # Positional argument order of the compiled functions.
        self.inps = [
            doc_var, doc_toks, query_var, qry_toks, cand_var, target_var,
            docmask_var, qmask_var, tok_var, tok_mask, candmask_var,
            feat_var, cloze_var, match_feat_var, use_char_var,
            use_char_q_var,
        ]

        # --- network ------------------------------------------------------
        if rlambda > 0.:
            noise = lasagne.init.GlorotNormal().sample(W_init.shape)
            W_pert = W_init + noise
        else:
            W_pert = W_init

        built = self.build_network(K, vocab_size, W_pert)
        probs_train, probs_val, network, W_emb, attentions = built
        self.predicted_probs = probs_train
        self.network = network

        def objective(probs):
            # Cross-entropy plus embedding penalty, and mean accuracy.
            xent = T.nnet.categorical_crossentropy(probs, target_var).mean()
            loss = xent + rlambda * penalty(W_emb - W_init)
            acc = lasagne.objectives.categorical_accuracy(
                    probs, target_var).mean()
            return loss, acc

        self.loss_fn, self.eval_fn = objective(self.predicted_probs)
        loss_fn_val, eval_fn_val = objective(probs_val)

        # --- optimisation and compiled functions --------------------------
        self.params = L.get_all_params(self.network, trainable=True)
        updates = lasagne.updates.adam(
                self.loss_fn, self.params, learning_rate=self.learning_rate)

        train_outputs = [self.loss_fn, self.eval_fn, self.predicted_probs]
        self.train_fn = theano.function(
                self.inps,
                train_outputs,
                updates=updates,
                on_unused_input='ignore')

        # Validation additionally returns whatever build_network supplied
        # as `attentions`, appended after the three standard outputs.
        val_outputs = [loss_fn_val, eval_fn_val, probs_val] + attentions
        self.validate_fn = theano.function(
                self.inps,
                val_outputs,
                on_unused_input='ignore')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号