def _get_submodule_code(op):
    """Build the CUDA ``__device__`` function source for one sub-operation.

    Args:
        op: Operation descriptor. This function reads these attributes:
            ``name`` (the emitted function name), ``param_names`` and
            ``types`` (zipped pairwise into the parameter list), ``nin``
            (how many leading entries of ``types`` are inputs), and
            ``operation`` (the statement emitted as the function body).
            Presumably a fusion sub-op object -- confirm against callers.

    Returns:
        str: Source text of the form::

            __device__ void <name>(<ctype> &<param>, ...) {
            typedef <ctype> in0_type; ...
            typedef <ctype> out0_type; ...
            <operation>;
            }

        with one extra ``'\\n'`` appended.

    Raises:
        KeyError: presumably, if ``_dtype_to_ctype`` (a module-level mapping
            defined elsewhere) has no entry for a dtype in ``op.types``.
    """
    # Each parameter is declared as a C++ reference ("<ctype> &<name>").
    # zip() stops at the shorter of param_names / types.
    parameters = ', '.join('%s &%s' % (_dtype_to_ctype[t], name)
                           for name, t in zip(op.param_names, op.types))
    # Emit typedefs ahead of the operation so its code can refer to operand
    # types as in<i>_type / out<i>_type. Output indices restart at 0.
    typedecl = ''.join('typedef %s in%d_type;\n' % (_dtype_to_ctype[t], i)
                       for i, t in enumerate(op.types[:op.nin]))
    typedecl += ''.join('typedef %s out%d_type;\n' % (_dtype_to_ctype[t], i)
                        for i, t in enumerate(op.types[op.nin:]))
    module_code = string.Template('''
__device__ void ${name}(${parameters}) {
${typedecl}
${operation};
}
''').substitute(
        name=op.name,
        parameters=parameters,
        operation=op.operation,
        typedecl=typedecl)
    return module_code + '\n'
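# NOTE(review): stray web-page navigation labels left over from the scraped
# source -- "评论列表" ("Comment list") and "文章目录" ("Table of contents").
# They are not part of the code.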