def testStringToTFIDF(self):
    """Checks tft.tfidf on a 3-document corpus with a 6-token vocabulary.

    Each document in column 'a' is tokenized with tf.string_split, mapped to
    integer ids with tft.string_to_int, and passed to tft.tfidf with a
    vocabulary size of 6. The expected output pairs each distinct token id
    ('index') with term frequency * IDF ('tf_idf').
    """
    def preprocessing_fn(inputs):
        inputs_as_ints = tft.string_to_int(tf.string_split(inputs['a']))
        out_index, out_values = tft.tfidf(inputs_as_ints, 6)
        return {
            'tf_idf': out_values,
            'index': out_index
        }

    input_data = [{'a': 'hello hello world'},
                  {'a': 'hello goodbye hello world'},
                  {'a': 'I like pie pie pie'}]
    input_schema = dataset_metadata.DatasetMetadata({
        'a': sch.ColumnSchema(tf.string, [], sch.FixedColumnRepresentation())
    })

    # Token ids implied by the expected 'index' values below:
    #   hello=0, pie=1, world=2, goodbye=4, and {I, like} = {3, 5}
    #   (I and like share the same tf-idf, so their relative order does not
    #   affect this check).
    #
    # Expected IDFs. The constants below equal 1 + log(4 / (df + 1)) with
    # N = 3 documents, i.e. they are consistent with the smoothed form
    # idf = 1 + log((N + 1) / (df + 1)):
    #   hello   = 1 + log(4/3) = 1.28768   (in 2 documents)
    #   world   = 1 + log(4/3)             (in 2 documents)
    #   goodbye = 1 + log(4/2) = 1.69314   (in 1 document)
    #   I       = 1 + log(4/2)             (in 1 document)
    #   like    = 1 + log(4/2)             (in 1 document)
    #   pie     = 1 + log(4/2)             (in 1 document)
    idf_df1 = 1.69314718056  # 1 + log(4/2): token appears in 1 document.
    idf_df2 = 1.28768207245  # 1 + log(4/3): token appears in 2 documents.

    # Term frequency is count / document length. Float literals keep these
    # fractions non-zero even without true division enabled.
    expected_transformed_data = [{
        # 'hello hello world': hello 2/3, world 1/3.
        'tf_idf': [(2.0 / 3) * idf_df2, (1.0 / 3) * idf_df2],
        'index': [0, 2]
    }, {
        # 'hello goodbye hello world': hello 2/4, world 1/4, goodbye 1/4.
        'tf_idf': [(2.0 / 4) * idf_df2,
                   (1.0 / 4) * idf_df2,
                   (1.0 / 4) * idf_df1],
        'index': [0, 2, 4]
    }, {
        # 'I like pie pie pie': pie 3/5, I 1/5, like 1/5.
        'tf_idf': [(3.0 / 5) * idf_df1,
                   (1.0 / 5) * idf_df1,
                   (1.0 / 5) * idf_df1],
        'index': [1, 3, 5]
    }]
    expected_transformed_schema = dataset_metadata.DatasetMetadata({
        'tf_idf': sch.ColumnSchema(tf.float32, [None],
                                   sch.ListColumnRepresentation()),
        'index': sch.ColumnSchema(tf.int64, [None],
                                  sch.ListColumnRepresentation())
    })
    self.assertAnalyzeAndTransformResults(
        input_data, input_schema, preprocessing_fn, expected_transformed_data,
        expected_transformed_schema)
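# (scraped web-page residue, not code) "Comment list"
# (scraped web-page residue, not code) "Table of contents"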