def main(argv=None):
    """Run a Tensorflow model on the Criteo dataset.

    Reads the cluster/task description from the TF_CONFIG environment
    variable, parses command-line flags, and then either performs a single
    evaluation pass (when --eval_only_summary_filename is set) and writes the
    metrics as JSON to that file, or hands control to learn_runner.run.

    Args:
      argv: argument vector including the program name; defaults to sys.argv.
    """
    env = json.loads(os.environ.get('TF_CONFIG', '{}'))
    # First find out if there's a task value on the environment variable.
    # If there is none or it is empty define a default one.
    task_data = env.get('task') or {'type': 'master', 'index': 0}
    argv = sys.argv if argv is None else argv
    args = create_parser().parse_args(args=argv[1:])

    # When the task carries a 'trial' id, write outputs to a per-trial
    # subdirectory. str() guards against a numeric id, which os.path.join
    # would reject.
    trial = task_data.get('trial')
    if trial is not None:
        output_dir = os.path.join(args.output_path, str(trial))
    else:
        output_dir = args.output_path

    # Do only evaluation if instructed so, or call Experiment's run.
    if args.eval_only_summary_filename:
        experiment = get_experiment_fn(args)(output_dir)
        # Note that evaluation here will appear as 'one_pass' in tensorboard.
        results = experiment.evaluate(delay_secs=0)
        # json cannot serialize numpy scalars/arrays, so convert them to
        # native Python values; values that are already native pass through.
        # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
        json_out = json.dumps({
            key: value.tolist() if hasattr(value, 'tolist') else value
            for key, value in results.items()
        })
        with tf.Session():
            tf.write_file(args.eval_only_summary_filename, json_out).run()
    else:
        learn_runner.run(experiment_fn=get_experiment_fn(args),
                         output_dir=output_dir)
# Example source code using tf.write_file() (Python)
def write_record(self, sess=None):
    """Write every image in self.shuffled_images to the record writer and
    store the dataset mean image.

    For each image container this reads the raw file bytes and adds the
    decoded image (scaled by 1/total_images) into self.dataset_mean. It then
    writes a tf.train.Example through self._Writer holding:
      - an int64 feature under self._Label_handle built from image_data,
      - the raw encoded image bytes under self._Image_handle,
      - the image name under self._Image_name.

    After the loop the mean image is converted to uint8 and PNG-encoded. The
    PNG is stored in self.mean_header_proto.mean_data, the proto is serialized
    to '<dataset_name>_mean.proto', and the PNG is also written to
    '<dataset_name>_mean.png'.

    Note: self.dataset_mean is replaced by its uint8-converted tensor.
    self._Writer is closed whether this returns normally or raises.

    Args:
        sess: session used to run the graph. Defaults to the current default
            session and is stored on self.sess.
    """
    with tf.name_scope('Dataset_Classification_Writer'):
        self.sess = tf.get_default_session() if sess is None else sess
        im_pth = tf.placeholder(tf.string)
        image_raw = tf.read_file(im_pth)
        image_pix = tf.image.convert_image_dtype(
            tf.image.decode_image(image_raw), tf.float32)
        total_images = len(self.shuffled_images)
        # Each step adds image / total_images. After the loop the variable
        # holds the average image, assuming dataset_mean starts at zero
        # (NOTE(review): initialization is not visible here).
        mean_assign = tf.assign(
            self.dataset_mean, self.dataset_mean + image_pix / total_images)
        print('\t\t Constructing Database')
        self.mean_header_proto.Image_headers.image_count = total_images
        try:
            for index, image_container in enumerate(self.shuffled_images):
                printProgressBar(index + 1, total_images)
                raw_bytes, _ = self.sess.run(
                    [image_raw, mean_assign],
                    feed_dict={im_pth: image_container.image_path})
                self.Param_dict[self._Label_handle] = self._int64_feature(
                    image_container.image_data)
                self.Param_dict[self._Image_handle] = self._bytes_feature(
                    raw_bytes)
                self.Param_dict[self._Image_name] = self._bytes_feature(
                    str.encode(image_container.image_name))
                example = tf.train.Example(
                    features=tf.train.Features(feature=self.Param_dict))
                self._Writer.write(example.SerializeToString())

            # Encode the mean once with self.sess. Tensor.eval() would need
            # self.sess to be the *default* session, which is not true when a
            # session is passed in explicitly.
            self.dataset_mean = tf.image.convert_image_dtype(
                self.dataset_mean, tf.uint8)
            encoded_mean = tf.image.encode_png(self.dataset_mean)
            mean_png = self.sess.run(encoded_mean)
            self.mean_header_proto.mean_data = mean_png
            with open(self.dataset_name + '_mean.proto', 'wb') as mean_proto_file:
                mean_proto_file.write(self.mean_header_proto.SerializeToString())
            self.sess.run(
                [tf.write_file(self.dataset_name + '_mean.png', mean_png)])
        finally:
            self._Writer.close()
#From: https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
def build_output(model):
    """Build (but do not run) an op that saves the translation result as a PNG.

    Each source image is concatenated with its translated output along
    axis 2, and the batch is reshaped into one tall image. Pixel values go
    through x * 127.5 + 127.5 with a saturating cast to uint8, and the result
    is PNG-encoded.

    The fixed sizes assume 256x256 RGB inputs and model outputs in roughly
    [-1, 1] (NOTE(review): confirm against the model definition).

    Returns:
        The tf.write_file op. Running it writes the PNG to
        FLAGS.target_image_path.
    """
    side_by_side = tf.concat(
        [model['source_images'], model['output_images']], axis=2)
    stacked_rows = FLAGS.batch_size * 256
    mosaic = tf.reshape(side_by_side, [stacked_rows, 512, 3])
    pixels = tf.saturate_cast(mosaic * 127.5 + 127.5, tf.uint8)
    png_bytes = tf.image.encode_png(pixels)
    return tf.write_file(FLAGS.target_image_path, png_bytes)
def decode(self, ids):
    """Transform a sequence of int ids into an image file.

    The ids are treated as uint8 pixel values, reshaped to
    (height, width, channels), PNG-encoded and written to a new temporary
    file.

    Args:
      ids: list of integers to be converted.

    Returns:
      Path to the temporary file where the image was saved.

    Raises:
      ValueError: if the ids are not of the appropriate size.
    """
    length = self._height * self._width * self._channels
    if len(ids) != length:
        raise ValueError("Length of ids (%d) must be height (%d) x width (%d) x "
                         "channels (%d); %d != %d.\n Ids: %s"
                         % (len(ids), self._height, self._width, self._channels,
                            len(ids), length, " ".join([str(i) for i in ids])))
    # Create the temp file only after validation, so a rejected input leaves
    # nothing behind. Close our handle right away: TF reopens the path itself.
    # The previous mkstemp() call never closed its file descriptor.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        tmp_file_path = tmp_file.name
    with tf.Graph().as_default():
        raw = tf.constant(ids, dtype=tf.uint8)
        img = tf.reshape(raw, [self._height, self._width, self._channels])
        png = tf.image.encode_png(img)
        op = tf.write_file(tmp_file_path, png)
        with tf.Session() as sess:
            sess.run(op)
    return tmp_file_path
def write_record(self, sess=None):
    """Write every image/sequence pair in self.shuffled_images to the record
    writer and store the dataset mean image.

    For each image container this reads the raw file bytes and adds the
    decoded image (scaled by 1/total_images) into self.dataset_mean. It then
    writes a tf.train.Example through self._Writer holding:
      - the encoded image_data under self._Seq_handle,
      - the encoded seq_mask under self._Seq_mask,
      - the raw encoded image bytes under self._Image_handle,
      - the encoded image_path under self._Image_name.

    After the loop the mean image is converted to uint8 and PNG-encoded. The
    PNG is stored in self.mean_header_proto.mean_data, the proto is serialized
    to '<dataset_name>_mean.proto', and the PNG is also written to
    '<dataset_name>_mean.png'.

    Note: self.dataset_mean is replaced by its uint8-converted tensor.
    self._Writer is closed whether this returns normally or raises.

    Args:
        sess: session used to run the graph. Defaults to the current default
            session and is stored on self.sess.
    """
    with tf.name_scope('Dataset_ImageSeqGen_Writer'):
        self.sess = tf.get_default_session() if sess is None else sess
        im_pth = tf.placeholder(tf.string)
        image_raw = tf.read_file(im_pth)
        image_pix = tf.image.convert_image_dtype(
            tf.image.decode_image(image_raw), tf.float32)
        total_images = len(self.shuffled_images)
        # Each step adds image / total_images. After the loop the variable
        # holds the average image, assuming dataset_mean starts at zero
        # (NOTE(review): initialization is not visible here).
        mean_assign = tf.assign(
            self.dataset_mean, self.dataset_mean + image_pix / total_images)
        print('\t\t Constructing Database')
        self.mean_header_proto.Image_headers.image_count = total_images
        try:
            for index, image_container in enumerate(self.shuffled_images):
                # (A leftover per-iteration debug print of total_images was
                # removed here; it interleaved with the progress bar.)
                printProgressBar(index + 1, total_images)
                raw_bytes, _ = self.sess.run(
                    [image_raw, mean_assign],
                    feed_dict={im_pth: image_container.image_path})
                self.Param_dict[self._Seq_handle] = self._bytes_feature(
                    str.encode(image_container.image_data))
                self.Param_dict[self._Seq_mask] = self._bytes_feature(
                    str.encode(image_container.seq_mask))
                self.Param_dict[self._Image_handle] = self._bytes_feature(
                    raw_bytes)
                self.Param_dict[self._Image_name] = self._bytes_feature(
                    str.encode(image_container.image_path))
                example = tf.train.Example(
                    features=tf.train.Features(feature=self.Param_dict))
                self._Writer.write(example.SerializeToString())

            # Encode the mean once with self.sess. Tensor.eval() would need
            # self.sess to be the *default* session, which is not true when a
            # session is passed in explicitly.
            self.dataset_mean = tf.image.convert_image_dtype(
                self.dataset_mean, tf.uint8)
            encoded_mean = tf.image.encode_png(self.dataset_mean)
            mean_png = self.sess.run(encoded_mean)
            self.mean_header_proto.mean_data = mean_png
            with open(self.dataset_name + '_mean.proto', 'wb') as mean_proto_file:
                mean_proto_file.write(self.mean_header_proto.SerializeToString())
            self.sess.run(
                [tf.write_file(self.dataset_name + '_mean.png', mean_png)])
        finally:
            self._Writer.close()
#From: https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
def merge_sm_with_tf(isomap_lists, confidence_lists, output_list):
    """Merge each group of isomap images into one PNG using a TF graph.

    For group j, isomap_lists[j] holds image paths, confidence_lists[j] holds
    the matching .npy confidence-map paths, and output_list[j] is the
    destination path.
      - An empty group is skipped.
      - A single-image group is copied unchanged.
      - Larger groups are merged with cnn_tf_graphs.merge_isomaps_softmax,
        cast to uint8 and written as PNG.

    Groups are processed in order of increasing size, so consecutive groups of
    the same size reuse one graph and session instead of rebuilding them.
    """
    import tensorflow as tf
    import cnn_tf_graphs
    from shutil import copyfile

    order = sorted(range(len(isomap_lists)), key=lambda idx: len(isomap_lists[idx]))
    isomap_lists = [isomap_lists[idx] for idx in order]
    confidence_lists = [confidence_lists[idx] for idx in order]
    output_list = [output_list[idx] for idx in order]

    merge_length = -1
    sess = None
    try:
        for j, isomap_list in enumerate(isomap_lists):
            with tf.Graph().as_default():
                if len(isomap_list) == 0:
                    continue
                if len(isomap_list) == 1:
                    copyfile(isomap_list[0], output_list[j])
                    continue
                if len(isomap_list) != merge_length:
                    # The placeholder count is baked into the graph, so a new
                    # group size needs a new graph and session. This graph is
                    # the current default inside the enclosing `with`.
                    if sess is not None:
                        sess.close()
                    placeholders = []
                    outpath = tf.placeholder(tf.string)
                    for _ in range(len(isomap_list)):
                        colour = tf.placeholder(
                            tf.float32, shape=(1, ISOMAP_SIZE, ISOMAP_SIZE, 3))
                        conf = tf.placeholder(
                            tf.float32, shape=(1, ISOMAP_SIZE, ISOMAP_SIZE, 1))
                        placeholders.append([colour, conf])
                    merged = tf.squeeze(
                        cnn_tf_graphs.merge_isomaps_softmax(placeholders))
                    merged_uint8 = tf.cast(merged, tf.uint8)
                    encoded = tf.image.encode_png(merged_uint8)
                    write_file_op = tf.write_file(outpath, encoded)
                    merge_length = len(isomap_list)
                    sess = tf.Session()
                print('merging', merge_length, 'images (max', len(isomap_lists[-1]),
                      ') idx', j, 'of', len(isomap_lists))
                feed_dict = {}
                for i, (colour_ph, conf_ph) in enumerate(placeholders):
                    # Keep only the first 3 channels of the image, then reverse
                    # their order. cv2 typically loads images as BGR
                    # (NOTE(review): confirm the merge graph expects RGB).
                    image = cv2.imread(isomap_list[i], cv2.IMREAD_UNCHANGED)
                    feed_dict[colour_ph] = np.expand_dims(
                        image[:, :, :3].astype(np.float32)[:, :, ::-1], axis=0)
                    feed_dict[conf_ph] = np.expand_dims(
                        np.load(confidence_lists[j][i]).astype(np.float32), axis=0)
                feed_dict[outpath] = output_list[j]
                sess.run(write_file_op, feed_dict=feed_dict)
    finally:
        # Previously the last session was never closed.
        if sess is not None:
            sess.close()
def transfer():
    """Stylize the image at FLAGS.content_path and write it to FLAGS.styled_path.

    The network weights are restored from FLAGS.ckpt_path. This may be a
    checkpoint path, or a directory whose latest checkpoint is used. The
    styled output is processed as follows:
      1. Cropped by FLAGS.padding on each spatial side.
      2. Mapped with x * 127.5 + 127.5 and saturate-cast to uint8.
      3. Its channel order is reversed.
      4. It is JPEG-encoded and written out.

    Raises:
        ValueError: if the checkpoint cannot be resolved, or if content_path is
            missing or is a directory. Explicit raises are used because
            `assert` checks are stripped under `python -O`.
    """
    if tf.gfile.IsDirectory(FLAGS.ckpt_path):
        ckpt_source_path = tf.train.latest_checkpoint(FLAGS.ckpt_path)
    elif tf.gfile.Exists(FLAGS.ckpt_path):
        ckpt_source_path = FLAGS.ckpt_path
    else:
        ckpt_source_path = None
    # latest_checkpoint() returns None when the directory has no checkpoint.
    if ckpt_source_path is None:
        raise ValueError('bad checkpoint')
    if (not tf.gfile.Exists(FLAGS.content_path)
            or tf.gfile.IsDirectory(FLAGS.content_path)):
        raise ValueError('bad content_path')

    image_contents = build_contents_reader()
    network = build_style_transfer_network(image_contents, training=False)

    # image_styled is indexed as [batch, dim1, dim2, channel]. Crop the
    # padding off spatial dims 1 and 2 and keep the full batch and channels.
    shape = tf.shape(network['image_styled'])
    crop_dim1 = shape[1] - 2 * FLAGS.padding
    crop_dim2 = shape[2] - 2 * FLAGS.padding
    image_styled = tf.slice(
        network['image_styled'],
        [0, FLAGS.padding, FLAGS.padding, 0],
        [-1, crop_dim1, crop_dim2, -1])
    image_styled = tf.squeeze(image_styled, [0])
    image_styled = image_styled * 127.5 + 127.5
    # Reverse the channel axis (axis 2 after dropping the batch dim).
    image_styled = tf.reverse(image_styled, [2])
    image_styled = tf.saturate_cast(image_styled, tf.uint8)
    image_styled = tf.image.encode_jpeg(image_styled)
    image_styled_writer = tf.write_file(FLAGS.styled_path, image_styled)

    with tf.Session() as session:
        tf.train.Saver().restore(session, ckpt_source_path)
        # Queue runners feed the dataset reader. Always stop and join them,
        # even if the write fails, so no threads are left running.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            session.run(image_styled_writer)
        finally:
            coord.request_stop()
            coord.join(threads)