def loop_fn_transition(time, previous_output, previous_state, previous_loop_state):
    """Transition step of a decoder ``loop_fn`` for ``tf.nn.raw_rnn``.

    At each step the previous cell output is scored against every encoder
    output through ``w['score']``, the scores are softmax-normalised into
    attention weights, and the attention-weighted sum of encoder outputs
    (the context vector) is concatenated with the previous output. That
    concatenation is projected with ``w['decoder']`` / ``b['decoder']``; the
    arg-max index is embedded via ``embeddings`` and becomes the next input
    (greedy decoding). Once every element is finished, ``pad_embedded`` is
    fed instead.

    Reads module-level names defined elsewhere in the file: ``tf``, ``w``,
    ``b``, ``decoder_lengths``, ``encoder_outputs``, ``embeddings`` and
    ``pad_embedded``.

    Args:
        time: Current decoding step.
        previous_output: Cell output from the previous step. Must be rank 2
            (it is right-multiplied by ``w['score']``); presumably
            ``[batch, hidden]`` -- TODO confirm against the cell.
        previous_state: Cell state from the previous step; passed through
            unchanged.
        previous_loop_state: Not read; this loop carries no extra state.

    Returns:
        Tuple ``(elements_finished, next_input, next_cell_state,
        emit_output, next_loop_state)`` as ``raw_rnn`` expects; the state and
        emitted output are the previous ones, and the loop state is ``None``.
    """
    # Element-wise "done" mask; presumably decoder_lengths holds one length
    # per batch element -- verify against the caller.
    elements_finished = time >= decoder_lengths

    def _greedy_next_input():
        """Return the embedding of the greedy prediction for this step."""
        # Query (previous_output @ W_score) as a column: [B, H, 1].
        query = tf.expand_dims(tf.matmul(previous_output, w['score']), 2)
        # encoder_outputs is [B, T, H], so the scores come out as [B, T, 1];
        # drop the unit axis instead of reshaping to a hard-coded num_steps.
        score = tf.squeeze(tf.matmul(encoder_outputs, query), axis=2)
        attention = tf.nn.softmax(score)  # normalised over encoder steps
        # [B, 1, T] @ [B, T, H] -> [B, 1, H] -> context vector [B, H].
        context = tf.squeeze(
            tf.matmul(tf.expand_dims(attention, 1), encoder_outputs), axis=1)
        context_and_output = tf.concat((context, previous_output), 1)
        # NOTE(review): the original also computed
        # tanh(context_and_output @ w['hdash'] + b['hdash']) but never used
        # it. If a Luong-style attentional hidden state was intended, the
        # projection below should consume that instead (which would change
        # the expected shape of w['decoder']) -- confirm with the model author.
        logits = tf.add(tf.matmul(context_and_output, w['decoder']),
                        b['decoder'])
        # softmax is monotonic, so argmax over the raw logits picks the same
        # token without the extra op.
        prediction = tf.argmax(logits, axis=1)
        return tf.nn.embedding_lookup(embeddings, prediction)

    finished = tf.reduce_all(elements_finished)
    # When every element is done, feed padding instead of a new prediction.
    next_input = tf.cond(finished, lambda: pad_embedded, _greedy_next_input)

    return (elements_finished,
            next_input,
            previous_state,
            previous_output,
            None)
# In[31]:
# Source: Seq2Seq_model_for_TextSummarizer-600L.py (file source code)
# Language: python
# Reads: 33
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents