def load_computation(self, computation_decl):
    """
    Load a computation and associated storage into the current execution state.

    The work is done in two phases over ``computation_decl.exop_block``:

    1. Every tensor view the computation touches is passed to
       ``self.device_tensor_view``. This covers:

       - the values of ``op_returns``;
       - each exop's inputs, write args and outputs;
       - the sources of the computation's returns;
       - the root views of its parameters.

       This phase is bracketed by ``start_allocate_computation`` /
       ``finish_allocate_computation``.
    2. Each exop is passed to ``self.generate_exop``. This phase is bracketed
       by ``start_define_computation`` / ``finish_define_computation``.

    Args:
        computation_decl: A ComputationDecl for the computation.

    Returns:
        An executable for the computation: the value returned by
        ``self.finish_load_computation(computation_decl)``. It is returned
        after ``self.run_device_tensor_initializations()`` has been called.
    """
    self.device_computation = computation_decl.device_computation
    exop_block = computation_decl.exop_block

    # Phase 1: register every tensor view used by the computation.
    self.start_allocate_computation(computation_decl)
    for return_decl in itervalues(computation_decl.op_returns):
        self.device_tensor_view(return_decl.tensor_view_decl)
    for exop in exop_block:
        for input_decl in exop.input_decls:
            self.device_tensor_view(input_decl.tensor_view_decl)
        for write_decl in exop.write_args:
            self.device_tensor_view(write_decl.tensor_view_decl)
        for output_decl in exop.output_decls:
            self.device_tensor_view(output_decl.tensor_view_decl)

    # Make sure we have values for ops that got optimized out
    for input_decl in computation_decl.returns.input_decls:
        output_decl = input_decl.source_output_decl
        source_op = output_decl.exop.op
        if isinstance(source_op, TensorValueOp):
            # For TensorValueOp sources, the view is derived from the tensor
            # decl of the op's value_tensor rather than from output_decl.
            # Resolve that tensor decl through this computation_decl (as the
            # parameter loop below does). The leftover ``exop`` loop variable
            # from the loop above is unbound when exop_block is empty and
            # otherwise refers to an arbitrary (last) exop.
            tensor_decl = computation_decl.get_tensor_decl(
                op=source_op.value_tensor)
            self.device_tensor_view(
                tensor_decl.get_tensor_view(source_op.tensor_description()))
        else:
            self.device_tensor_view(output_decl.tensor_view_decl)

    for param in computation_decl.computation_op.parameters:
        tensor_decl = computation_decl.get_tensor_decl(op=param.tensor)
        self.device_tensor_view(tensor_decl.root_tensor_view_decl)
    self.finish_allocate_computation(computation_decl)

    # Phase 2: hand each exop, in block order, to generate_exop.
    self.start_define_computation(computation_decl)
    for exop in exop_block:
        self.generate_exop(exop)
    self.finish_define_computation(computation_decl)

    executor = self.finish_load_computation(computation_decl)
    self.run_device_tensor_initializations()
    return executor
评论列表
文章目录