def __init__(self, K, vocab_size, num_chars, W_init,
        nhidden, embed_dim, dropout, train_emb, char_dim, use_feat, gating_fn,
        save_attn=False):
    """Build the model graph and compile Theano train/validate functions.

    Stores the hyperparameters on the instance, declares the symbolic
    inputs, constructs the network via ``self.build_network`` (which
    returns both a stochastic training output and a deterministic
    validation output), then compiles:

    * ``self.train_fn``    -- loss, accuracy, predictions; applies Adam updates.
    * ``self.validate_fn`` -- loss, accuracy, predictions plus attention maps.

    ``W_init`` is an optional pre-trained embedding matrix; when it is
    None a Glorot-normal sample of shape (vocab_size, embed_dim) is drawn.
    """
    # --- configuration / hyperparameters ---
    self.nhidden = nhidden
    self.embed_dim = embed_dim
    self.dropout = dropout
    self.train_emb = train_emb
    self.char_dim = char_dim
    self.learning_rate = LEARNING_RATE
    self.num_chars = num_chars
    self.use_feat = use_feat
    self.save_attn = save_attn
    self.gating_fn = gating_fn
    # character-level embeddings are active iff char_dim is non-zero
    self.use_chars = self.char_dim != 0

    # Fall back to a random Glorot-normal embedding matrix when no
    # pre-trained one is supplied.
    if W_init is None:
        W_init = lasagne.init.GlorotNormal().sample((vocab_size, self.embed_dim))

    # --- symbolic inputs ---
    doc = T.itensor3('doc')
    qry = T.itensor3('quer')
    cand = T.wtensor3('cand')
    doc_mask = T.bmatrix('doc_mask')
    qry_mask = T.bmatrix('q_mask')
    cand_mask = T.bmatrix('c_mask')
    answer = T.ivector('ans')
    feat = T.imatrix('feat')
    doc_chars = T.imatrix('dchars')
    qry_chars = T.imatrix('qchars')
    tok = T.imatrix('tok')
    tok_mask = T.bmatrix('tok_mask')
    cloze = T.ivector('cloze')
    # NOTE: this ordering is the calling convention for the compiled
    # functions below -- keep it in sync with the data pipeline.
    self.inps = [doc, doc_chars, qry, qry_chars, cand, answer, doc_mask,
            qry_mask, tok, tok_mask, cand_mask, feat, cloze]

    # --- network graph: training output + deterministic validation output ---
    (self.predicted_probs, probs_val, self.network,
            W_emb, attentions) = self.build_network(K, vocab_size, W_init)

    # --- objectives (mean cross-entropy) and accuracies, per graph ---
    self.loss_fn = T.nnet.categorical_crossentropy(
            self.predicted_probs, answer).mean()
    self.eval_fn = lasagne.objectives.categorical_accuracy(
            self.predicted_probs, answer).mean()
    loss_val = T.nnet.categorical_crossentropy(probs_val, answer).mean()
    acc_val = lasagne.objectives.categorical_accuracy(probs_val, answer).mean()

    # --- compiled functions ---
    self.params = L.get_all_params(self.network, trainable=True)
    updates = lasagne.updates.adam(
            self.loss_fn, self.params, learning_rate=self.learning_rate)
    self.train_fn = theano.function(
            self.inps,
            [self.loss_fn, self.eval_fn, self.predicted_probs],
            updates=updates,
            on_unused_input='warn')
    # Validation additionally returns the attention maps for inspection.
    self.validate_fn = theano.function(
            self.inps,
            [loss_val, acc_val, probs_val] + attentions,
            on_unused_input='warn')