def create_session_config(log_device_placement=False,
                          enable_graph_rewriter=False,
                          gpu_mem_fraction=0.95,
                          use_tpu=False):
    """Build the ``tf.ConfigProto`` used to configure a TensorFlow session.

    Args:
        log_device_placement: Forwarded unchanged to
            ``tf.ConfigProto(log_device_placement=...)``.
        enable_graph_rewriter: When true (and ``use_tpu`` is false), the
            graph options carry a ``RewriterConfig`` whose optimizer list is
            "pruning", "constfold", "arithmetic", "layout". When false, the
            graph options instead use ``tf.OptimizerOptions`` at level ``L1``
            with function inlining turned off.
        gpu_mem_fraction: Forwarded as
            ``tf.GPUOptions(per_process_gpu_memory_fraction=...)``.
        use_tpu: When true, a default ``tf.GraphOptions()`` is used and
            ``enable_graph_rewriter`` is ignored.

    Returns:
        A ``tf.ConfigProto`` with ``allow_soft_placement=True`` plus the
        graph options, GPU options and device-placement logging flag
        selected above.
    """
    if use_tpu:
        # TPU path: no rewriter or optimizer overrides are applied.
        graph_opts = tf.GraphOptions()
    elif enable_graph_rewriter:
        rewriter = rewriter_config_pb2.RewriterConfig()
        # Order matches the original one-by-one appends.
        rewriter.optimizers.extend(
            ["pruning", "constfold", "arithmetic", "layout"])
        graph_opts = tf.GraphOptions(rewrite_options=rewriter)
    else:
        basic_optimizer = tf.OptimizerOptions(
            opt_level=tf.OptimizerOptions.L1, do_function_inlining=False)
        graph_opts = tf.GraphOptions(optimizer_options=basic_optimizer)

    gpu_opts = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_mem_fraction)

    return tf.ConfigProto(
        allow_soft_placement=True,
        graph_options=graph_opts,
        gpu_options=gpu_opts,
        log_device_placement=log_device_placement)
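# NOTE: leftover blog-page navigation labels ("comment list" and
# "article table of contents") from the page this was copied from; not code.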