def check_numeric_input(self, input_name, input_value):
    """Validate a numeric input and store it on ``self`` as a length-``P`` float array.

    The value is normalised to a 1-D ``float64`` array with ``self.P``
    elements and assigned to the attribute named ``input_name``.

    Accepted forms:

    * a ``numpy.ndarray`` (any shape) holding either ``self.P`` elements or a
      single element. It is flattened and converted to float; a single
      element is repeated ``self.P`` times.
    * a real scalar: Python ``int``/``float`` or a NumPy integer/floating
      scalar (``bool`` is rejected). It is repeated ``self.P`` times.
    * a ``list`` or ``tuple`` of length ``self.P`` or length 1. Each item is
      converted with ``float()``; a single item is repeated ``self.P`` times.

    Args:
        input_name: Name of the attribute to set on ``self``; also used in
            error messages.
        input_value: The user-supplied value.

    Raises:
        ValueError: If the number of elements is neither ``self.P`` nor 1, or
            if ``input_value`` has an unsupported type.
    """
    n_expected = self.P

    if isinstance(input_value, np.ndarray):
        if input_value.size not in (n_expected, 1):
            raise ValueError("length of %s is %d; should be %d"
                             % (input_name, input_value.size, n_expected))
        # Flatten so (P, 1) / (1, 1) inputs do not yield 2-D attributes that
        # would broadcast unexpectedly against the 1-D arrays built below,
        # and convert to float like the scalar and sequence branches do.
        values = np.asarray(input_value, dtype=float).ravel()
    elif (isinstance(input_value, (int, float, np.integer, np.floating))
          and not isinstance(input_value, bool)):
        # NumPy scalars (e.g. np.int64, np.float32) are not subclasses of
        # Python int/float, so they must be listed explicitly. bool is a
        # subclass of int and is excluded to keep rejecting True/False.
        values = np.full(n_expected, float(input_value))
    elif isinstance(input_value, (list, tuple)):
        if len(input_value) not in (n_expected, 1):
            raise ValueError("length of %s is %d; should be %d"
                             % (input_name, len(input_value), n_expected))
        # float() per item (rather than a dtype cast) keeps rejecting items
        # such as None instead of silently turning them into NaN.
        values = np.array([float(x) for x in input_value])
    else:
        raise ValueError("user provided %s with an unsupported type" % (input_name))

    # A single element (already validated above) is repeated to length P.
    if values.size != n_expected:
        values = np.full(n_expected, values[0])

    setattr(self, input_name, values)
评论列表
文章目录