def plot_attn_mat(model_dev, all_examples):
    """Decode a batch of examples and return the attention matrix of the first one.

    Relies on names defined outside this function: ``sess``, ``bucket_id``,
    ``rev_sent_vocab``, ``rev_parse_vocab``, ``data_utils``, ``np`` and ``tf``.

    Args:
        model_dev: model exposing ``batch_size``, ``get_decode_batch`` and
            ``step``. Its ``batch_size`` is overwritten with
            ``len(all_examples)`` as a side effect.
        all_examples: sequence of examples. Element ``[0]`` of each is a list
            of sentence token ids (looked up in ``rev_sent_vocab``). Element
            ``[1]`` is a list of gold parse ids (looked up in
            ``rev_parse_vocab``).

    Returns:
        Tuple ``(encoder_inputs, sent_text, gold_parse, decoded_parse, mat)``
        for the first example (index 0):

        - ``encoder_inputs``: as returned by ``get_decode_batch``.
        - ``sent_text``: the sentence tokens as strings.
        - ``gold_parse``: the gold parse symbols.
        - ``decoded_parse``: the predicted parse symbols, truncated before the
          first ``data_utils.EOS_ID``; ids >= ``len(rev_parse_vocab)`` become
          ``"_UNK"``.
        - ``mat``: ``attns[:, 0, :].T``.
    """
    model_dev.batch_size = len(all_examples)
    token_ids = [example[0] for example in all_examples]
    gold_ids = [example[1] for example in all_examples]
    # One independent empty decoder sequence per example. ``[[]] * n`` would
    # alias a single list object across every entry.
    dec_ids = [[] for _ in token_ids]
    # Materialize the pairs: on Python 3 ``zip`` is a one-shot iterator, which
    # breaks if get_decode_batch indexes or takes the length of its input.
    encoder_inputs, decoder_inputs, target_weights = model_dev.get_decode_batch(
        {bucket_id: list(zip(token_ids, dec_ids))}, bucket_id)
    # NOTE(review): the trailing True is presumably step()'s forward-only flag;
    # confirm against its signature.
    _, _, output_logits, attns = model_dev.step(
        sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)

    # Greedy decode: argmax along axis 1 of each logit array, stacked and
    # transposed so that row ``sent_id`` is read as that example's prediction.
    outputs = [np.argmax(logit, axis=1) for logit in output_logits]
    to_decode = np.array(outputs).T
    sent_id = 0
    parse = list(to_decode[sent_id, :])
    if data_utils.EOS_ID in parse:
        parse = parse[:parse.index(data_utils.EOS_ID)]

    vocab_size = len(rev_parse_vocab)
    decoded_parse = [
        tf.compat.as_str(rev_parse_vocab[tok]) if tok < vocab_size else "_UNK"
        for tok in parse
    ]
    gold_parse = [tf.compat.as_str(rev_parse_vocab[tok]) for tok in gold_ids[sent_id]]
    sent_text = [tf.compat.as_str(rev_sent_vocab[tok]) for tok in token_ids[sent_id]]
    # NOTE(review): index 0 on axis 1 appears to select the same example as
    # sent_id above. Confirm the layout of attns.
    mat = attns[:, 0, :].T
    return encoder_inputs, sent_text, gold_parse, decoded_parse, mat
# 评论列表 (web-page residue: "Comment list")
# 文章目录 (web-page residue: "Table of contents")