def extract_argmax_and_embed(embedding, output_projection=None,
                             update_embedding=True):
    """Build a loop function that embeds the decoder's previous argmax symbol.

    Used by the decoder: the returned function takes the previous output,
    optionally applies ``output_projection``, picks the highest-scoring
    symbol with ``argmax`` along axis 1, and returns that symbol's embedding.

    :param embedding: embedding tensor indexed by symbol id. Because the
        lookup uses ``tf.nn.embedding_lookup``, a list of sharded embedding
        tensors is also accepted.
    :param output_projection: None or a pair ``(W, B)``. If provided, each
        fed previous output is first multiplied by ``W`` and then ``B`` is
        added.
    :param update_embedding: if False, ``tf.stop_gradient`` is applied to the
        looked-up embedding so no gradient flows back into ``embedding``
        through this feedback path. Defaults to True (original behaviour).
    :return: a function ``loop_function(prev, _)`` that returns the embedding
        of ``argmax(prev, 1)``.
    """
    def loop_function(prev, _):
        # The second argument is ignored (presumably the decoder step index;
        # verify against the caller).
        if output_projection is not None:
            prev = tf.matmul(prev, output_projection[0]) + output_projection[1]
        # Index of the highest-scoring symbol along axis 1; assumes prev is
        # (batch, num_symbols) after projection -- TODO confirm.
        prev_symbol = tf.argmax(prev, 1)
        # Same result as tf.gather for a single embedding tensor; also
        # supports a list of sharded embedding tensors.
        emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
        if not update_embedding:
            emb_prev = tf.stop_gradient(emb_prev)
        return emb_prev

    return loop_function
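# NOTE(review): the original (likely non-English) comments here were lost to
# mis-encoding. The legible fragments ("RNN", "test", "t", "t+1") suggest they
# said that, at test time, the RNN decoder's output at step t is fed back as
# its input at step t+1 -- which is what loop_function above is built for.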
# [web-page residue] Comment list
# [web-page residue] Table of contents