def for_loop(step_function, inputs, initial_hidden_states, go_backwards):
    """Run ``step_function`` over the leading (time) axis via ``theano.scan``.

    Parameters
    ----------
    step_function : callable
        Per-step function handed to ``theano.scan``.
    inputs :
        Passed to ``theano.scan`` as ``sequences``. The time axis must be
        the first axis.
    initial_hidden_states :
        Passed to ``theano.scan`` as ``outputs_info``.
    go_backwards : bool
        Forwarded unchanged to ``theano.scan``'s ``go_backwards`` argument.

    Returns
    -------
    list
        The scan outputs, always as a list (a single output is wrapped).
        Each output has its first two axes swapped so that the batch axis
        is first again. Note the asymmetry: inputs are time-major, outputs
        are batch-major.
    """
    # theano.scan returns (outputs, updates); only the outputs are kept.
    # NOTE(review): if step_function relies on shared variables or random
    # streams whose updates must be applied, discarding them is a bug --
    # confirm against callers.
    results = theano.scan(
        step_function,
        sequences=inputs,
        outputs_info=initial_hidden_states,
        go_backwards=go_backwards,
    )[0]
    # With a single output, scan returns a bare variable instead of a list;
    # wrap it so both cases are handled uniformly below.
    if not isinstance(results, list):
        results = [results]
    # Swap axes 0 and 1 and keep the rest in order, putting the batch axis
    # back in front. `range` (rather than the Python 2-only `xrange`) keeps
    # this working on Python 3.
    # Assumes the file-level `dimshuffle(tensor, pattern)` permutes axes
    # like Theano's `TensorVariable.dimshuffle` -- TODO confirm.
    return [
        dimshuffle(tensor, [1, 0] + list(range(2, tensor.ndim)))
        for tensor in results
    ]
评论列表
文章目录