def input_fn(batch_size, file_name, num_threads=1):
    """
    Input function that creates the feature dict and label for cross-validation.

    The input is read for exactly one epoch without shuffling. Every line is
    parsed as ``len(COLUMNS)`` comma-separated string fields, then converted
    per column: CONTINUOUS_COLUMNS to float32, CATEGORICAL_COLUMNS through
    ``dense_to_sparse``, and LABEL_COLUMN to int32.

    :param batch_size: number of examples per batch
    :param file_name: CSV input, passed to ``read_batch_examples`` as the
        file pattern
    :param num_threads: number of reader threads. Defaults to 1 because the
        ``read_batch_examples`` documentation requires a single thread for a
        predictable, repeatable reading order in evaluation/prediction;
        pass a larger value only when row order does not matter.
    :return: tuple ``(feature_cols, label)`` where ``feature_cols`` maps each
        continuous/categorical column name to its tensor and ``label`` is the
        int32 label tensor
    """
    def parse_csv_line(line):
        # All fields are decoded as strings (default ''); numeric conversion
        # is done per column further down.
        record_defaults = [tf.constant([''], dtype=tf.string)] * len(COLUMNS)
        return tf.decode_csv(line, record_defaults, field_delim=",")

    examples_op = tf.contrib.learn.read_batch_examples(
        file_name,
        batch_size=batch_size,
        reader=tf.TextLineReader,
        num_threads=num_threads,
        num_epochs=1,
        randomize_input=False,
        parse_fn=parse_csv_line)

    # Column i of the parsed batch belongs to the i-th header in COLUMNS.
    # NOTE(review): assumes examples_op is a 2-D tensor indexed as
    # [row, column] -- confirm against the TensorFlow version in use.
    examples_dict = {
        header: examples_op[:, i] for i, header in enumerate(COLUMNS)
    }

    feature_cols = {
        k: tf.string_to_number(examples_dict[k], out_type=tf.float32)
        for k in CONTINUOUS_COLUMNS
    }
    feature_cols.update({
        k: dense_to_sparse(examples_dict[k])
        for k in CATEGORICAL_COLUMNS
    })

    label = tf.string_to_number(examples_dict[LABEL_COLUMN], out_type=tf.int32)
    return feature_cols, label
wide_deep_evaluate_predict.py 文件源码
python
阅读 38
收藏 0
点赞 0
评论 0
评论列表
文章目录