model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:sportsball 作者: jgershen 项目源码 文件源码
def build_model(train_file, attr_file, model_out, algorithm='ridge'):
  """Fit a regressor on pickled training data and pickle the fitted model.

  Args:
    train_file: path to a pandas pickle, loaded with ``pd.read_pickle``.
    attr_file: attribute file parsed by ``read_attrs``. The first attribute is
      used as the regression target; the remaining ones are the features.
    model_out: path the fitted estimator is pickled to.
    algorithm: one of 'ridge', 'linear', 'lasso', 'rf', 'en'
      (Ridge, LinearRegression, Lasso, RandomForestRegressor, ElasticNet).

  Raises:
    NotImplementedError: if ``algorithm`` is not one of the supported names.
  """
  classifiers = ['ridge', 'linear', 'lasso', 'rf', 'en']
  if algorithm not in classifiers:
    raise NotImplementedError("only implemented algorithms: " + str(classifiers))

  train_data = pd.read_pickle(train_file)

  attrs = read_attrs(attr_file)
  target_attr = attrs[0]
  usable_attrs = attrs[1:]

  if algorithm == 'ridge':
    clf = Ridge()
  elif algorithm == 'linear':
    clf = LinearRegression()
  elif algorithm == 'lasso':
    clf = Lasso()
  elif algorithm == 'en':
    clf = ElasticNet()
  else:
    clf = RandomForestRegressor()

  logger.debug("Modeling '%s'", target_attr)
  logger.debug("    train set (%d): %s", len(train_data), train_file)
  logger.debug("  Algorithm: %s", algorithm)

  clf.fit(train_data[usable_attrs], train_data[target_attr])

  # `coef_` is a fitted attribute: it only exists after fit(), and only on the
  # linear models (RandomForestRegressor has none). Logging it before fitting,
  # as the original did, meant this branch never ran.
  if hasattr(clf, 'coef_'):
    logger.debug('Coefficients:')
    for attr_name, coef in zip(usable_attrs, clf.coef_):
      # Both values go through the logger's lazy %-formatting; the original
      # passed the pre-formatted coefficient as an unconsumed extra arg.
      logger.debug('    %-20s: %20.4f', attr_name, coef)

  # Close the output file deterministically so the pickle is fully flushed.
  with open(model_out, 'wb') as model_fh:
    pickle.dump(clf, model_fh)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号