def make_node(self, s_x_):
    """Build the graph node for this op from a single input variable.

    The input must already carry a ``type`` whose ``dtype`` is exactly
    ``'float32'``. This check is made on the raw input, before any
    conversion. The input is then moved onto the GPU context inferred from
    it. The resulting node has that converted variable as its only input.
    Its only output is a fresh variable of the same type.

    Args:
        s_x_: the symbolic input variable (assumed to be a Theano-style
            Variable exposing ``.type.dtype`` -- TODO confirm with callers).

    Returns:
        ``th.Apply(self, [converted_input], [converted_input.type()])``.

    Raises:
        ValueError: if ``s_x_.type.dtype`` is not ``'float32'``.
    """
    input_dtype = s_x_.type.dtype
    if input_dtype != 'float32':
        raise ValueError('Only float32 is allowed')

    # Resolve which GPU context the input belongs to, then make sure the
    # input lives there.
    context_name = infer_context_name(s_x_)
    gpu_input = as_gpuarray_variable(s_x_, context_name)

    # The output mirrors the (converted) input's type.
    out_type = gpu_input.type
    return th.Apply(self, [gpu_input], [out_type()])
评论列表
文章目录