def __call__(self, inputs):
    """Evaluate every output symbol of this function on concrete input values.

    Args:
        inputs: Sequence of values matched positionally to the placeholders
            in ``self.inputs``; each value must expose ``.shape``. An
            optional trailing ``numbers.Number`` is treated as the
            learning-phase flag. It is stored on ``self.is_train``, so it
            persists to later calls that omit it, and it is stripped before
            the remaining values are bound. Extra or missing values are
            silently dropped by ``zip``.

    Returns:
        list: One NumPy array per entry of ``self.output``, in order. Each
        array is the first output of that symbol's executor.
    """
    # Guard against empty inputs, where ``inputs[-1]`` would raise IndexError.
    # ``len`` is used rather than truthiness so array-like inputs don't
    # trigger an ambiguous-truth-value error.
    if len(inputs) > 0 and isinstance(inputs[-1], Number):
        self.is_train = inputs[-1]
        inputs = inputs[:-1]

    # Placeholder name/value pairs do not depend on which output is being
    # evaluated, so build them once instead of once per output.
    named_inputs = [(k.name, v) for k, v in zip(self.inputs, inputs)]
    input_data = dict(named_inputs)

    ret_outputs = []
    for x in self.output:
        # Values from dfs_get_bind_values(x) override fed inputs that share
        # a name. copy + update keeps that precedence without requiring
        # string keys (unlike ``dict(d, **kw)``).
        data = dict(input_data)
        data.update(dfs_get_bind_values(x))

        args = set(x.symbol.list_arguments())
        # Only pass shapes for placeholders this symbol actually consumes.
        # The remaining argument shapes are presumably left to simple_bind's
        # shape inference.
        data_shapes = {name: v.shape for name, v in named_inputs if name in args}
        # grad_req='null': bind for evaluation only, no gradient buffers.
        executor = x.symbol.simple_bind(mx.cpu(), grad_req='null', **data_shapes)
        for name in executor.arg_dict:
            if name in data:
                executor.arg_dict[name][:] = data[name]
        outputs = executor.forward(is_train=self.is_train)
        ret_outputs.append(outputs[0].asnumpy())
    return ret_outputs
# Comment list
# Table of contents