def _step(time, sequence_length, min_sequence_length, max_sequence_length,
          zero_logit, generate_logit):
    """Compute the logit for one time step, masking entries whose sequence has ended.

    The step is resolved in two stages:

    1. If ``time >= max_sequence_length``, ``generate_logit`` is skipped and
       ``zero_logit`` is used directly.
    2. If ``time < min_sequence_length``, the result of stage 1 is returned
       unchanged. Otherwise entries with ``time >= sequence_length`` are
       replaced by the matching entries of ``zero_logit`` via ``math_ops.select``.

    Args:
      time: Current step index, compared against the three length arguments.
      sequence_length: Per-entry sequence lengths. Presumably a (batch,) int
        vector — TODO confirm against caller.
      min_sequence_length: Assumed to be the minimum of ``sequence_length``
        (verify against caller). Below it no masking is performed.
      max_sequence_length: Assumed to be the maximum of ``sequence_length``
        (verify against caller). At or beyond it, generation is skipped.
      zero_logit: Tensor substituted for finished entries. It also supplies the
        static shape of the result.
      generate_logit: Zero-argument callable that builds this step's logit.

    Returns:
      The logit tensor for this step, with its static shape set from
      ``zero_logit``.
    """
    # Stage 1: only build the generated logit while some sequence may still be
    # active; otherwise fall back to zero_logit.
    new_logit = control_flow_ops.cond(
        time < max_sequence_length, generate_logit, lambda: zero_logit)

    def _keep_all():
        # Every sequence is still active: use the stage-1 logit unchanged.
        return new_logit

    def _mask_finished():
        # Entries whose sequence has ended (time >= sequence_length) take
        # zero_logit; the rest keep the stage-1 logit.
        finished = time >= sequence_length
        return math_ops.select(finished, zero_logit, new_logit)

    # Stage 2: skip the per-entry select when no sequence can have ended yet.
    # A distinct name (new_logit) is captured above, so the branch closures do
    # not depend on `logit` not yet being rebound.
    logit = control_flow_ops.cond(
        time < min_sequence_length, _keep_all, _mask_finished)

    # cond() may not preserve static shape information; restore it from
    # zero_logit. (Setting logit's shape to its own shape would be a no-op.)
    logit.set_shape(zero_logit.get_shape())
    return logit
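# 评论列表 ("comment list") — web-page navigation residue, not part of the code
# 文章目录 ("table of contents") — web-page navigation residue, not part of the code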