def step_call(self, x, h0, c0, condition_on, *params):
    """Unroll the recurrence over ``x`` with ``theano.scan``, then apply the output net.

    ``self._step`` is applied ``x.shape[0]`` times over the sequences built by
    ``self.call_seqs``. It carries two recurrent states initialised from ``h0``
    and ``c0``. The resulting ``h`` sequence is then fed to
    ``self.output_net.step_call``.

    Args:
        x: Input tensor. Its first axis is used as the number of scan steps
            (presumably time-major, ``(n_steps, n_samples, ...)`` -- TODO confirm).
        h0: Initial value of the first recurrent state (``h``).
        c0: Initial value of the second recurrent state (``c``); presumably
            an LSTM cell state -- confirm against ``_step``.
        condition_on: Forwarded to ``self.call_seqs`` when building the scan
            sequences.
        *params: Parameters. ``self.get_recurrent_args`` selects those passed
            to the recurrence; ``self.get_output_args`` selects those passed
            to the output net.

    Returns:
        A ``(results, updates)`` pair.

        ``results`` is an ``OrderedDict`` with keys, in this order:
            - ``'h'``: the recurrent state sequence.
            - ``'p'``: the ``'p'`` entry of the output net's result.
            - ``'z'``: the ``'z'`` entry of the output net's result.

        ``updates`` is the update dictionary returned by ``theano.scan``.
        The second state sequence (``c``) is not returned.
    """
    n_steps = x.shape[0]
    seqs = self.call_seqs(x, condition_on, *params)
    outputs_info = [h0, c0]
    non_seqs = self.get_recurrent_args(*params)
    # strict=True: scan requires every shared variable used by ``_step`` to be
    # passed explicitly (here via ``non_sequences``).
    (h, _c), updates = theano.scan(
        self._step,
        sequences=seqs,
        outputs_info=outputs_info,
        non_sequences=non_seqs,
        name=self.name + '_recurrent_steps',
        n_steps=n_steps,
        strict=True)

    o_params = self.get_output_args(*params)
    out_net_out = self.output_net.step_call(h, *o_params)
    preact = out_net_out['z']
    p = out_net_out['p']

    # Build from (key, value) pairs: before Python 3.6 (PEP 468), keyword
    # arguments reached the OrderedDict constructor unordered, so
    # OrderedDict(h=..., p=..., z=...) did not preserve the h/p/z order.
    results = OrderedDict([('h', h), ('p', p), ('z', preact)])
    return results, updates
# Stray web-page residue (not part of the code): "评论列表" ("Comment list"),
# "文章目录" ("Table of contents").