def get_candidates_representations_in_sentence(self, sentence_candidate_answers, sentence_attentioned_hidden_states):
    """Compute one final representation per candidate answer in a sentence.

    For each candidate ``i``, row ``i`` of ``sentence_candidate_answers`` is
    used as a list of node indices into ``sentence_attentioned_hidden_states``.
    The gathered hidden states are passed to
    ``self.get_candidate_answer_final_representations``, and the per-candidate
    results are stacked along a new leading axis.

    Args:
        sentence_candidate_answers: Index tensor whose leading dimension
            enumerates the candidates. Presumably int with shape
            (num_candidates, nodes_per_candidate) -- TODO confirm.
        sentence_attentioned_hidden_states: Tensor whose leading dimension is
            indexed by the node ids above.

    Returns:
        Tensor of shape (num_candidates, 2 * self.config.hidden_dim).

    NOTE(review): candidate 0 is gathered unconditionally, so an empty
    ``sentence_candidate_answers`` is not supported.
    """
    candidate_answer_num = tf.shape(sentence_candidate_answers)[0]
    # `warning` replaces the deprecated `warn` alias. During graph construction
    # these print symbolic tensors, not runtime values.
    logging.warning('candidate_answer_num:{}'.format(candidate_answer_num))
    logging.warning('sentence_candidate_answers:{}'.format(sentence_candidate_answers))

    def _candidate_representation(idx_cand):
        """Gather one candidate's node hidden states and reduce them to a vector."""
        nodeids = tf.gather(sentence_candidate_answers, idx_cand)
        hidden_list = tf.gather(sentence_attentioned_hidden_states, nodeids)
        return self.get_candidate_answer_final_representations(hidden_list)

    # Candidate 0 is computed outside the loop, as in the original code.
    # Presumably this keeps any variables created by
    # get_candidate_answer_final_representations outside the while-loop's
    # control flow -- confirm before moving it into the loop.
    first_representation = _candidate_representation(0)

    # Accumulate into a TensorArray instead of tf.concat-ing onto a growing
    # tensor each iteration, which re-copies every previous row (O(n^2)).
    # The element_shape enforces the same per-row shape the old
    # shape_invariants did.
    representations = tf.TensorArray(
        dtype=first_representation.dtype,
        size=candidate_answer_num,
        element_shape=tf.TensorShape([2 * self.config.hidden_dim]))
    representations = representations.write(0, first_representation)

    def _cond(idx_cand, _):
        return tf.less(idx_cand, candidate_answer_num)

    def _body(idx_cand, acc):
        acc = acc.write(idx_cand, _candidate_representation(idx_cand))
        return tf.add(idx_cand, 1), acc

    _, representations = tf.while_loop(
        _cond, _body, [tf.constant(1), representations])
    return representations.stack()
ccrc_model.py 文件源码
python
阅读 51
收藏 0
点赞 0
评论 0
评论列表
文章目录