def apply_input_bias_and_extract_argmax_fn_factory(input_bias):
    """Build a loop-function factory that applies ``input_bias`` before argmax.

    The returned factory produces loop functions that first pass the previous
    decoder output through ``project_and_apply_input_bias`` together with
    ``input_bias``, so the argmax that picks the fed-back symbol is taken over
    the biased scores rather than the raw output.

    Args:
      input_bias: Value forwarded unchanged as the third argument of
        ``project_and_apply_input_bias`` on every call of the loop function.
        Its expected type/shape is defined by that helper (not visible here);
        presumably a per-batch mask/weighting over the vocabulary -- TODO
        confirm against ``project_and_apply_input_bias``.

    Returns:
      ``fn_factory(embedding, output_projection=None, update_embedding=True)``,
      a callable that returns a ``loop_function(prev, _)``.
    """

    def fn_factory(embedding, output_projection=None, update_embedding=True):
        """Get a loop_function that extracts the previous symbol and embeds it.

        Args:
          embedding: Embedding tensor for symbols; the symbol ids chosen by
            argmax are looked up in it with ``embedding_lookup``.
          output_projection: Forwarded to ``project_and_apply_input_bias``.
            Presumably None or a pair (W, B) such that each fed previous
            output is first multiplied by W and has B added -- confirm
            against that helper.
          update_embedding: Boolean; if False, the looked-up embedding is
            wrapped in ``stop_gradient`` so gradients do not propagate into
            ``embedding`` through this path.

        Returns:
          ``loop_function(prev, _)``, which returns the pair
          ``(emb_prev, prev_symbol)``.
        """

        def loop_function(prev, _):
            # The second argument is ignored (in the usual loop-function
            # protocol it is the step index -- assumption, not shown here).
            biased = project_and_apply_input_bias(prev, output_projection,
                                                  input_bias)
            # Pick the highest-scoring id along axis 1; presumably ``biased``
            # is (batch, vocab) -- TODO confirm.
            prev_symbol = math_ops.argmax(biased, 1)
            # Note that gradients will not propagate through the second
            # parameter of embedding_lookup.
            emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
            if not update_embedding:
                emb_prev = array_ops.stop_gradient(emb_prev)
            # NOTE(review): returns an (embedding, symbol) pair rather than
            # only the embedding; the decoder calling this must unpack both.
            return emb_prev, prev_symbol

        return loop_function

    return fn_factory
text_corrector_models.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录