def call_seqs(self, x, condition_on, level, *params):
    '''Prepares the input for __call__.

    Computes the pre-activation of `x` using the network that feeds the
    requested level, then optionally adds a conditioning term to it.

    Args:
        x (T.tensor): input.
        condition_on (T.tensor or None): tensor to condition recurrence on.
            Added to the pre-activation when it is not None.
        level (int): recurrent level. Level 0 goes through
            `self.input_net`; any other level goes through
            `self.inter_nets[level - 1]`.
        *params: list of theano.shared.

    Returns:
        list: list of scan inputs (a one-element list holding the
            pre-activation).

    '''
    if level == 0:
        # Bottom level: the raw input is projected by the input network.
        net_params = self.get_input_args(*params)
        net = self.input_net
    else:
        # Higher levels are fed by the inter-level network just below them,
        # hence the `level - 1` offset for both parameters and network.
        net_params = self.get_inter_args(level - 1, *params)
        net = self.inter_nets[level - 1]

    a = net.step_preact(x, *net_params)

    if condition_on is not None:
        a += condition_on

    return [a]
# (Web-page residue, not code: "Comment list")
# (Web-page residue, not code: "Table of contents")