def __Recurrent(name, hidden_dims, step_fn, inputs, non_sequences=None, h0s=None):
    """Build a (possibly multi-state) recurrence with ``theano.scan``.

    Normalizes the arguments to lists, creates a learnable initial state for
    every recurrent state whose initial value was not supplied, and runs
    ``step_fn`` over ``inputs`` via ``theano.scan``.

    Args:
        name: Prefix for parameter names. Learned initial states are created
            through ``lib.param`` under ``name + '.h0_<i>'``.
        hidden_dims: Size of each recurrent state: an int, or a list of ints
            with one entry per state.
        step_fn: Step function, passed unchanged to ``theano.scan``.
        inputs: A sequence variable or a list of them, passed to scan as
            ``sequences``. The batch size for learned initial states is read
            from ``inputs[0].shape[1]`` (presumably a time-major
            ``(time, batch, ...)`` layout -- TODO confirm against callers).
        non_sequences: Extra arguments forwarded by scan to ``step_fn``.
            ``None`` (the default) means no extra arguments.
        h0s: Optional initial states, one per entry of ``hidden_dims``, given
            as a list/tuple or as a single variable. When ``h0s`` is ``None``,
            or an entry is ``None``, that state is a learned zero-initialized
            vector of size ``hidden_dims[i]`` tiled across the batch. The
            caller's list is never modified.

    Returns:
        The first element of the pair returned by ``theano.scan`` (the
        outputs; scan's updates are discarded).
    """
    if not isinstance(inputs, list):
        inputs = [inputs]
    if not isinstance(hidden_dims, list):
        hidden_dims = [hidden_dims]
    if non_sequences is None:
        # Sentinel instead of a shared mutable ``[]`` default.
        non_sequences = []

    if h0s is None:
        h0s = [None] * len(hidden_dims)
    elif isinstance(h0s, (list, tuple)):
        # Work on a copy: the entries filled in below must not leak back into
        # the caller's list (and tuples do not support item assignment).
        h0s = list(h0s)
    else:
        # A single initial-state variable, normalized like ``inputs`` above.
        h0s = [h0s]

    for i, dim in enumerate(hidden_dims):
        if h0s[i] is None:
            # NOTE(review): lib.param is project-local; presumably it creates
            # (or looks up) a trainable shared variable with this name.
            h0_unbatched = lib.param(
                name + '.h0_' + str(i),
                numpy.zeros((dim,), dtype=theano.config.floatX)
            )
            num_batches = inputs[0].shape[1]
            # Tile the learned vector into a (num_batches, dim) matrix.
            h0s[i] = T.alloc(h0_unbatched, num_batches, dim)
        # Force every axis non-broadcastable, presumably so the broadcast
        # pattern matches step_fn's outputs (scan rejects mismatches, e.g.
        # when a dimension happens to be statically 1).
        h0s[i] = T.patternbroadcast(h0s[i], [False] * h0s[i].ndim)

    outputs, _ = theano.scan(
        step_fn,
        sequences=inputs,
        outputs_info=h0s,
        non_sequences=non_sequences
    )

    return outputs
# NOTE: web-page scrape residue, not code: "comment list" / "table of contents".