def _transform(self, dataset):
    """Run the inference graph over ``dataset`` and return the renamed result.

    The graph definition returned by ``self._optimize_for_inference()`` is
    imported into a fresh ``tf.Graph``. Columns are fed to the graph as
    described by ``self.getInputMapping()`` (pairs of column name and tensor
    name), and the fetched outputs are renamed to the column names given by
    ``self.getOutputMapping()`` (pairs of tensor name and column name).

    :param dataset: the dataset handed to ``tfs.analyze``
        (presumably a Spark DataFrame -- confirm against callers).
    :return: the object returned by ``tfs.map_blocks``, with each output
        column renamed to its mapped name where the two differ.
    """
    frozen_graph_def = self._optimize_for_inference()
    col_to_tnsr_pairs = self.getInputMapping()
    tnsr_to_col_pairs = self.getOutputMapping()

    target_graph = tf.Graph()
    # NOTE(review): entering the session context is presumably what makes
    # `target_graph` the default graph for `tf.import_graph_def` -- confirm.
    with tf.Session(graph=target_graph):
        prepared_df = tfs.analyze(dataset)

        # Op names of every requested output tensor.
        fetch_op_names = []
        for out_tnsr_name, _ in tnsr_to_col_pairs:
            fetch_op_names.append(tfx.op_name(out_tnsr_name))

        # Load the graph; the import must happen before tensors are looked up.
        tf.import_graph_def(
            graph_def=frozen_graph_def,
            name='',
            return_elements=fetch_op_names,
        )

        # Map each placeholder op name to the column that feeds it.
        placeholder_to_col = {}
        for src_col, in_tnsr_name in col_to_tnsr_pairs:
            placeholder_to_col[self._getSparkDlOpName(in_tnsr_name)] = src_col

        fetch_tensors = []
        for fetch_name in fetch_op_names:
            fetch_tensors.append(tfx.get_tensor(fetch_name, target_graph))

        result_df = tfs.map_blocks(
            fetch_tensors, prepared_df, feed_dict=placeholder_to_col)

        # Output columns come back named after their ops; rename them to the
        # column names requested in the output mapping.
        for out_tnsr_name, wanted_col in tnsr_to_col_pairs:
            produced_col = tfx.op_name(out_tnsr_name, target_graph)
            if produced_col == wanted_col:
                continue
            result_df = result_df.withColumnRenamed(produced_col, wanted_col)

    return result_df
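# Comment list (page navigation text from the source web page, not code)
# Article table of contents (page navigation text from the source web page, not code)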