plot_attn_mask.py source code

python

Project: seq2seq_parser  Author: trangham283
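import numpy as np
import tensorflow as tf

import data_utils  # project module; provides EOS_ID

# NOTE: sess, bucket_id, rev_parse_vocab and rev_sent_vocab are module-level
# globals defined elsewhere in plot_attn_mask.py.


def ids_to_tokens(ids, rev_vocab):
    """Map token ids to strings, emitting "_UNK" for out-of-vocabulary ids."""
    return [tf.compat.as_str(rev_vocab[i]) if i < len(rev_vocab) else "_UNK"
            for i in ids]


def plot_attn_mat(model_dev, all_examples, sent_id=0):
    """Decode all_examples as one batch; return the attention matrix of one sentence."""
    model_dev.batch_size = len(all_examples)
    token_ids = [x[0] for x in all_examples]
    gold_ids = [x[1] for x in all_examples]
    dec_ids = [[] for _ in token_ids]  # empty decoder targets: we are decoding
    # list(...) so the batch builder gets a real list under Python 3, not an iterator.
    encoder_inputs, decoder_inputs, target_weights = model_dev.get_decode_batch(
        {bucket_id: list(zip(token_ids, dec_ids))}, bucket_id)
    # The final True presumably selects forward-only (inference) mode.
    _, _, output_logits, attns = model_dev.step(
        sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
    #_, _, output_logits, attns = model_dev.step_with_attn(sess, encoder_inputs, decoder_inputs,target_weights, bucket_id, True)

    # Greedy decoding: arg-max symbol at each decoder step, then make batch-major.
    outputs = [np.argmax(logit, axis=1) for logit in output_logits]
    to_decode = np.array(outputs).T  # shape: (batch, decoder_steps)
    parse_all = list(to_decode[sent_id, :])
    parse = parse_all[:]
    # Truncate the parse at the first end-of-sequence symbol.
    if data_utils.EOS_ID in parse:
        parse = parse[:parse.index(data_utils.EOS_ID)]
    decoded_parse = ids_to_tokens(parse, rev_parse_vocab)
    decoded_parse_all = ids_to_tokens(parse_all, rev_parse_vocab)  # untruncated; not returned
    gold_parse = [tf.compat.as_str(rev_parse_vocab[i]) for i in gold_ids[sent_id]]
    sent_text = [tf.compat.as_str(rev_sent_vocab[i]) for i in token_ids[sent_id]]
    # Assuming attns is (decoder_steps, batch, encoder_len), mat is
    # (encoder_len, decoder_steps) for the same sentence as the parse above.
    mat = np.asarray(attns)[:, sent_id, :].T
    return encoder_inputs, sent_text, gold_parse, decoded_parse, mat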
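The post-processing inside `plot_attn_mat` (arg-max per decoder step, transpose to batch-major, truncate at EOS, map ids to strings with an `_UNK` fallback) can be exercised without a trained model. The sketch below is a minimal NumPy-only illustration, not the project's code; `EOS_ID = 2`, the toy `REV_VOCAB`, and the helper name `greedy_decode` are hypothetical stand-ins for `data_utils.EOS_ID` and the real parse vocabulary.

```python
import numpy as np

# Hypothetical stand-ins for data_utils.EOS_ID and rev_parse_vocab.
EOS_ID = 2
REV_VOCAB = ["_PAD", "_GO", "_EOS", "(S", "(NP", ")"]


def greedy_decode(output_logits, sent_id=0, eos_id=EOS_ID, rev_vocab=REV_VOCAB):
    """Decode one sentence from a list of per-step logits, each (batch, vocab_size)."""
    # One arg-max per decoder step gives (steps, batch); transpose to (batch, steps).
    to_decode = np.array([np.argmax(logit, axis=1) for logit in output_logits]).T
    parse = [int(i) for i in to_decode[sent_id, :]]
    # Drop the first EOS and everything after it.
    if eos_id in parse:
        parse = parse[:parse.index(eos_id)]
    # Ids outside the vocabulary become "_UNK".
    return [rev_vocab[i] if i < len(rev_vocab) else "_UNK" for i in parse]


# Toy logits: 4 decoder steps, batch of 1, 7 output classes (id 6 is not in REV_VOCAB).
steps = [3, 6, 2, 5]
logits = [np.eye(7)[[s]] for s in steps]  # each array has shape (1, 7)
print(greedy_decode(logits))  # ['(S', '_UNK']
```

Here the EOS at step 3 cuts off the trailing `)`, and id 6 falls back to `_UNK`, which is the same pair of rules the original function applies to `decoded_parse`.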
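The returned `mat` is a slice of `attns` for one sentence, transposed. Assuming `attns` is laid out as (decoder_steps, batch, encoder_len), as the `[:, 0, :]` indexing suggests, `mat` has one row per source position and one column per decoder step, so each column is an attention distribution over the input. The sketch below shows that slicing on a toy tensor and adds a hypothetical helper, `strongest_alignments`, that pairs each output token with its most-attended source token as a quick text alternative to a heatmap. It assumes encoder positions line up one-to-one with `sent_text`; any padding or reordering of the encoder input would have to be undone first.

```python
import numpy as np


def attention_matrix(attns, sent_id=0):
    """Slice one sentence out of a (decoder_steps, batch, encoder_len) tensor.

    Returns shape (encoder_len, decoder_steps), mirroring attns[:, 0, :].T.
    """
    return np.asarray(attns)[:, sent_id, :].T


def strongest_alignments(mat, sent_text, decoded_parse):
    """Pair each decoded token with the source token it attends to most."""
    pairs = []
    for step, token in enumerate(decoded_parse):
        src = int(np.argmax(mat[:, step]))  # row index = encoder position
        pairs.append((token, sent_text[src]))
    return pairs


# Toy attention: 2 decoder steps, batch of 1, 3 encoder positions.
attns = np.array([[[0.7, 0.2, 0.1]],
                  [[0.1, 0.1, 0.8]]])
mat = attention_matrix(attns)
print(mat.shape)  # (3, 2)
print(strongest_alignments(mat, ["the", "dog", "barks"], ["(NP", "(VP"]))
```

The second call prints `[('(NP', 'the'), ('(VP', 'barks')]`. Transposing puts source tokens on the vertical axis, which is also the natural orientation when `mat` is later drawn as an image.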