import theano
import theano.tensor as T


def queue_transform(feature_strengths, feature_vects, return_strengths=False):
"""
Process features according to a "fragmented queue", where each timestep
gets a size-1 window onto a feature queue. Effectively,
feature_strengths gives how much to push onto queue
feature_vects gives what to push on
pop weights are tied to feature_strengths
output is a size-1 peek (without popping)
Parameters:
- feature_strengths: float32 tensor of shape (batch, push_timestep) in [0,1]
- feature_vects: float32 tensor of shape (batch, push_timestep, feature_dim)
Returns:
- peek_vects: float32 tensor of shape (batch, timestep, feature_dim)
"""
n_batch, n_time, n_feature = feature_vects.shape
cum_sum_str = T.extra_ops.cumsum(feature_strengths, 1)
    # We will be working in (batch, timestep, push_timestep).
    # For each timestep, subtracting out the total strength pushed before that
    # timestep and then clipping to [0, 1] leaves the cumulative sums of just
    # the features visible in that timestep's size-1 window.
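    # timestep_adjustments[b, t, 0]: total strength pushed (and hence popped)
    # before timestep t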
timestep_adjustments = T.shape_padright(cum_sum_str - feature_strengths)
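    # push_time_cumsum[b, 0, s]: total strength pushed up through push step s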
push_time_cumsum = T.shape_padaxis(cum_sum_str, 1)
relative_cumsum = push_time_cumsum - timestep_adjustments
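    # Clipping to [0, 1] restricts each timestep's view to a unit-mass window
    # starting immediately after the mass already popped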
capped_cumsum = T.minimum(T.maximum(relative_cumsum, 0), 1)
# Now we can recover the peek strengths by taking a diff
    shifted = T.concatenate([T.zeros((n_batch, n_time, 1)), capped_cumsum[:, :, :-1]], 2)
    peek_strengths = capped_cumsum - shifted
# Peek strengths is now (batch, timestep, push_timestep)
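    # Weighted sum over push steps:
    #   result[b, t, :] = sum_s peek_strengths[b, t, s] * feature_vects[b, s, :]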
result = T.batched_dot(peek_strengths, feature_vects)
if return_strengths:
return peek_strengths, result
else:
return result
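

# A minimal usage sketch (toy shapes and values, assumed rather than taken
# from the original): compile the transform and run it on a tiny hand-made
# batch, returning the peek strengths alongside the peeked vectors.
if __name__ == '__main__':
    import numpy as np

    strengths_var = T.fmatrix('strengths')  # (batch, push_timestep)
    vects_var = T.ftensor3('vects')         # (batch, push_timestep, feature_dim)
    peek_strengths_var, peek_vects_var = queue_transform(
        strengths_var, vects_var, return_strengths=True)
    fn = theano.function([strengths_var, vects_var],
                         [peek_strengths_var, peek_vects_var])

    strengths = np.array([[1.0, 0.6, 0.5]], dtype='float32')
    vects = np.eye(3, dtype='float32')[np.newaxis]  # one-hot feature vectors
    peek_strengths, peek_vects = fn(strengths, vects)
    # With one-hot feature vectors, peek_vects equals peek_strengths:
    # [[[1.0, 0.0, 0.0], [0.0, 0.6, 0.4], [0.0, 0.0, 0.5]]]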