multi_input.py source file

python

Project: aboleth  Author: data61
import numpy as np
import pandas as pd


def input_fn(df):
    """Format the downloaded data."""
    # CONTINUOUS_COLUMNS, CATEGORICAL_COLUMNS and LABEL_COLUMN are expected to
    # be defined at module level; they are not part of this excerpt.
    # Stack the continuous feature columns into a float32 array of shape
    # (n_rows, n_continuous_columns).
    continuous_cols = [df[k].values for k in CONTINUOUS_COLUMNS]
    X_con = np.stack(continuous_cols).astype(np.float32).T

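    # Standardise each column to zero mean and unit standard deviation.
    # Note: a constant column has zero std, so the division would yield NaN.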
    X_con -= X_con.mean(axis=0)
    X_con /= X_con.std(axis=0)

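    # Encode each categorical column as integer category codes: one-hot encode
    # it with get_dummies, then take the column index of the active indicator
    # in every row. Each result has shape (n_rows, 1).
    categ_cols = [np.where(pd.get_dummies(df[k]).values)[1][:, np.newaxis]
                  for k in CATEGORICAL_COLUMNS]
    # Number of distinct categories found in each categorical column.
    n_values = [np.amax(c) + 1 for c in categ_cols]
    X_cat = np.concatenate(categ_cols, axis=1).astype(np.int32)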

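    # Extract the label column as a NumPy array of shape (n_rows, 1).
    label = df[LABEL_COLUMN].values[:, np.newaxis]

    # Return the standardised continuous features, the integer-coded
    # categorical features, the per-column category counts, and the labels.
    return X_con, X_cat, n_values, label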
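
To make the shapes and encodings concrete, here is a minimal sketch that runs the same steps on a toy DataFrame. The column names and values are hypothetical, picked only for illustration; the real constants come from the rest of multi_input.py, which is not shown here. The function body is repeated so that the block runs on its own.

```python
import numpy as np
import pandas as pd

# Hypothetical column names, chosen only for this illustration.
CONTINUOUS_COLUMNS = ["age", "hours"]
CATEGORICAL_COLUMNS = ["workclass", "sex"]
LABEL_COLUMN = "label"


def input_fn(df):
    """Same steps as the function above, repeated so this block is runnable."""
    continuous_cols = [df[k].values for k in CONTINUOUS_COLUMNS]
    X_con = np.stack(continuous_cols).astype(np.float32).T
    X_con -= X_con.mean(axis=0)
    X_con /= X_con.std(axis=0)

    categ_cols = [np.where(pd.get_dummies(df[k]).values)[1][:, np.newaxis]
                  for k in CATEGORICAL_COLUMNS]
    n_values = [np.amax(c) + 1 for c in categ_cols]
    X_cat = np.concatenate(categ_cols, axis=1).astype(np.int32)

    label = df[LABEL_COLUMN].values[:, np.newaxis]
    return X_con, X_cat, n_values, label


# Toy data: four rows, no missing values.
df = pd.DataFrame({
    "age": [25, 38, 52, 41],
    "hours": [40, 50, 60, 30],
    "workclass": ["private", "gov", "private", "self"],
    "sex": ["f", "m", "m", "f"],
    "label": [0, 1, 1, 0],
})

X_con, X_cat, n_values, label = input_fn(df)

print(X_con.shape)                 # (4, 2)
print(X_cat.tolist())              # [[1, 0], [0, 1], [1, 1], [2, 0]]
print([int(n) for n in n_values])  # [3, 2]
print(label.shape)                 # (4, 1)
```

The category codes follow the sorted order of the dummy columns (gov = 0, private = 1, self = 2), so they depend on which values appear in the data passed in. The `np.where` trick also assumes every row has exactly one active indicator; a missing value in a categorical column produces an all-zero dummy row and would shift the codes out of alignment with the rows.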