def rnn_block(input_block, state_length):
    """Build a fully connected RNN as a TensorFlow Fold block.

    For every element of the input sequence, the current state vector is
    concatenated with the processed element, and the result is fed through a
    fully connected layer to yield the next state vector.

    Args:
      input_block: Block applied to each input element before it is
        concatenated with the current state vector.
      state_length: Length of the RNN state vector.

    Returns:
      RNN Block (seq of input_block inputs -> output state)
    """
    # The step block receives a (state, element) pair: the state passes
    # through untouched while the element is run through input_block.
    state_and_input = (td.Identity(), input_block)
    merged = state_and_input >> td.Concat()
    step = merged >> td.Function(td.FC(state_length))
    # Folding starts from an all-zeros state of length state_length.
    initial_state = tf.zeros(state_length)
    return td.Fold(step, initial_state)
# All characters are lowercase, so subtract 'a' to make them 0-indexed.
评论列表
文章目录