def check_data(X, X_names, Y):
    """Validate a feature matrix, its column names and a +1/-1 label vector.

    Parameters
    ----------
    X : numpy.ndarray
        2-dimensional feature matrix with N rows and P columns.
    X_names : list
        P unique column names, one per column of ``X``.
    Y : numpy.ndarray
        Labels; ``len(Y)`` must equal N and every entry must be 1 or -1.

    Raises
    ------
    AssertionError
        If any of the type, size, uniqueness or value checks fails.

    Warns
    -----
    UserWarning
        If ``X_names`` has no ``'(Intercept)'`` entry, or if every label in
        ``Y`` is identical (all 1 or all -1).

    Notes
    -----
    The checks are ``assert`` statements, so they are skipped when Python
    runs with ``-O``.
    """
    # type checks
    # Exact type match on purpose: ndarray subclasses are rejected, exactly
    # as the error messages state.
    assert type(X) is np.ndarray, "type(X) should be numpy.ndarray"
    assert type(Y) is np.ndarray, "type(Y) should be numpy.ndarray"
    assert type(X_names) is list, "X_names should be a list"

    # sizes and uniqueness
    # Check the rank first so a 1-D (or 3-D) X fails with a readable
    # assertion instead of an opaque tuple-unpacking ValueError.
    assert X.ndim == 2, 'X must be a 2-dimensional array (rows x columns)'
    N, P = X.shape
    assert N > 0, 'X matrix must have at least 1 row'
    assert P > 0, 'X matrix must have at least 1 column'
    assert len(Y) == N, 'len(Y) should be same as # of rows in X'
    assert len(set(X_names)) == len(X_names), 'X_names is not unique'
    assert len(X_names) == P, 'len(X_names) should be same as # of cols in X'

    # X_matrix values
    if '(Intercept)' in X_names:
        intercept_column = X[:, X_names.index('(Intercept)')]
        assert np.all(intercept_column == 1.0), "'(Intercept)' column should only be composed of 1s"
    else:
        warnings.warn("there is no column named '(Intercept)' in X_names")
    assert np.all(~np.isnan(X)), 'X has nan entries'
    assert np.all(~np.isinf(X)), 'X has inf entries'

    # Y vector values
    # np.all reduces over the whole array in one vectorized pass; the builtin
    # all() iterates in Python and breaks on arrays with more than one column.
    assert np.all((Y == 1) | (Y == -1)), 'Y[i] should = [-1,1] for all i'
    if np.all(Y == 1):
        warnings.warn("all Y_i == 1 for all i")
    if np.all(Y == -1):
        warnings.warn("all Y_i == -1 for all i")
    # TODO (optional) collect warnings and return those?
评论列表
文章目录