def testBasic(self):
    """Loads the half_plus_two session bundle and checks assets and signatures.

    Verifies that the two asset filename tensors resolve to files inside the
    bundle's assets directory, that exactly one Signatures proto is stored in
    the signatures collection, and then delegates the regression and named
    signature checks to helper methods. The session is always closed.
    """
    base_path = tf.test.test_src_dir_path(
        "contrib/session_bundle/example/half_plus_two/00000123")
    tf.reset_default_graph()
    sess, meta_graph_def = session_bundle.load_session_bundle_from_path(
        base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2}))
    self.assertTrue(sess)
    try:
        asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
        with sess.as_default():
            path1, path2 = sess.run(["filename1:0", "filename2:0"])
            self.assertEqual(
                compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
            self.assertEqual(
                compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)

            collection_def = meta_graph_def.collection_def
            signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
            # assertEqual: the assertEquals alias is deprecated (removed in 3.12).
            self.assertEqual(len(signatures_any), 1)

            signatures = manifest_pb2.Signatures()
            signatures_any[0].Unpack(signatures)
            self._checkRegressionSignature(signatures, sess)
            self._checkNamedSigantures(signatures, sess)
    finally:
        # as_default() does not close the session; release it explicitly.
        sess.close()
python类as_bytes()的实例源码
def main():
    """Entry point: open the global events writer and run the configured task.

    Reads the log directory from the ``LOGDIR`` environment variable, stores
    an ``EventsWriter`` for ``<LOGDIR>/events`` in the module-level ``writer``
    global, runs the worker or parameter-server routine selected by
    ``config.task_type``, and always closes the writer afterwards.

    Raises:
      KeyError: if ``LOGDIR`` is not set in the environment.
      ValueError: if ``config.task_type`` is neither 'worker' nor 'ps'.
    """
    global writer
    config = load_config()
    # todo: factor out common logic
    logdir = os.environ["LOGDIR"]
    writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(logdir + '/events'))
    try:
        if config.task_type == 'worker':
            run_worker()
        elif config.task_type == 'ps':
            run_ps()
        else:
            # Explicit raise: an `assert` here would be stripped under -O.
            raise ValueError("Unknown task type " + str(config.task_type))
    finally:
        # Close the events file even when the task raises.
        writer.Close()
def _write_assets(assets_directory, assets_filename):
    """Writes an asset file for the half_plus_two example model.

    Creates ``assets_directory`` (recursively) if it does not already exist,
    then writes the fixed contents ``"asset-file-contents"`` to
    ``assets_directory/assets_filename``.

    Args:
      assets_directory: Directory in which the asset file is written.
      assets_filename: Name of the asset file to write.

    Returns:
      The path (as bytes) of the asset file that was written.
    """
    if not file_io.file_exists(assets_directory):
        file_io.recursive_create_dir(assets_directory)
    # Both components are converted to bytes, so the joined path is bytes.
    path = os.path.join(
        compat.as_bytes(assets_directory),
        compat.as_bytes(assets_filename))
    file_io.write_string_to_file(path, "asset-file-contents")
    return path
experiment.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 34
收藏 0
点赞 0
评论 0
def _maybe_export(self, eval_result):  # pylint: disable=unused-argument
    """Runs every configured export strategy against the estimator.

    Each strategy exports into ``<model_dir>/export/<strategy.name>``, with
    all path components converted to bytes. ``eval_result`` is accepted but
    currently not used.

    Returns:
      A list holding each strategy's ``export`` return value, in the order of
      ``self._export_strategies``.
    """
    model_dir = compat.as_bytes(self._estimator.model_dir)
    export_root = os.path.join(model_dir, compat.as_bytes("export"))
    # TODO(soergel): possibly, allow users to decide whether to export here
    # based on the eval_result (e.g., to keep the best export).
    return [
        strategy.export(
            self._estimator,
            os.path.join(compat.as_bytes(export_root),
                         compat.as_bytes(strategy.name)))
        for strategy in self._export_strategies
    ]
saved_model_export_utils.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 21
收藏 0
点赞 0
评论 0
def get_timestamped_export_dir(export_dir_base):
    """Returns a path for a new export subdirectory named by the current time.

    The subdirectory name is the current time in whole seconds since the
    epoch (``int(time.time())``), so exports made in later seconds get larger
    names, even across separate runs. The directory is not created here.

    NOTE(review): two calls within the same second return the same path;
    callers exporting more than once per second must handle that collision.

    Args:
      export_dir_base: Base directory under which exported graphs and
        checkpoints are written.

    Returns:
      The joined path, as bytes.
    """
    stamp = str(int(time.time()))
    base = compat.as_bytes(export_dir_base)
    return os.path.join(base, compat.as_bytes(stamp))
# create a simple parser that pulls the export_version from the directory.
saved_model_export_utils_test.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 28
收藏 0
点赞 0
评论 0
def test_get_most_recent_export(self):
    """Checks that get_most_recent_export returns the newest of four exports."""
    # os.path.join (not string concatenation) so "export/" is created inside
    # the fresh temp dir rather than as a sibling next to it.
    export_dir_base = os.path.join(tempfile.mkdtemp(), "export/")
    gfile.MkDir(export_dir_base)
    for _ in range(3):
        _create_test_export_dir(export_dir_base)
    export_dir_4 = _create_test_export_dir(export_dir_base)

    most_recent_export_dir, most_recent_export_version = (
        saved_model_export_utils.get_most_recent_export(export_dir_base))

    self.assertEqual(compat.as_bytes(export_dir_4),
                     compat.as_bytes(most_recent_export_dir))
    self.assertEqual(
        compat.as_bytes(export_dir_4),
        os.path.join(compat.as_bytes(export_dir_base),
                     compat.as_bytes(str(most_recent_export_version))))
def visualize_graph_in_tfboard(filename, output='./log'):
    """Writes the graph stored in a SavedModel file so TensorBoard can show it.

    Parses ``filename`` as a serialized ``SavedModel`` proto, imports the
    GraphDef of its single MetaGraph into the session's graph, and writes
    that graph to an event file under ``output``.

    Args:
      filename: Path to a serialized SavedModel (e.g. ``saved_model.pb``).
      output: Log directory passed to ``tf.summary.FileWriter``.

    Exits the process with status 1 unless the SavedModel contains exactly
    one MetaGraph.
    """
    with tf.Session() as sess:
        with gfile.FastGFile(filename, 'rb') as f:
            data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        num_graphs = len(sm.meta_graphs)
        if num_graphs != 1:
            if num_graphs == 0:
                print('No graph found in {}.'.format(filename))
            else:
                print('More than one graph found. Not sure which to write')
            sys.exit(1)
        tf.import_graph_def(sm.meta_graphs[0].graph_def)
        train_writer = tf.summary.FileWriter(output)
        train_writer.add_graph(sess.graph)
        # close() flushes pending events; without it the graph may never be
        # written to disk.
        train_writer.close()
    print("Please execute `tensorboard --logdir {}` to view graph".format(output))
def __init__(self, dir):
    # Make sure the log directory exists; an existing one is accepted.
    os.makedirs(dir, exist_ok=True)
    self.dir = dir
    self.step = 1
    # TensorFlow is imported here rather than at module level, presumably so
    # it is only required when this writer is actually constructed.
    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.core.util import event_pb2
    from tensorflow.python.util import compat
    self.tf = tf
    self.event_pb2 = event_pb2
    self.pywrap_tensorflow = pywrap_tensorflow
    events_path = osp.join(osp.abspath(dir), 'events')
    self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(events_path))
def __init__(self, dir):
    # Make sure the log directory exists; an existing one is accepted.
    os.makedirs(dir, exist_ok=True)
    self.dir = dir
    self.step = 1
    # TensorFlow is imported here rather than at module level, presumably so
    # it is only required when this writer is actually constructed.
    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.core.util import event_pb2
    from tensorflow.python.util import compat
    self.tf = tf
    self.event_pb2 = event_pb2
    self.pywrap_tensorflow = pywrap_tensorflow
    events_path = osp.join(osp.abspath(dir), 'events')
    self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(events_path))
def reset(target, containers=None, config=None):
    """Resets resource containers on ``target`` through ``TF_Reset``.

    ``target`` (unless None) and every container name are converted to bytes
    before the call; ``containers=None`` is passed on as an empty list.
    """
    encoded_target = target if target is None else compat.as_bytes(target)
    if containers is None:
        container_names = []
    else:
        container_names = [compat.as_bytes(name) for name in containers]
    tf_session.TF_Reset(encoded_target, container_names, config)
def _name_list(tensor_list):
    """Returns the ``name`` of each item in ``tensor_list`` encoded as bytes.

    Helper for transitioning to the new session API: the legacy run path
    identifies fetches and targets by name.

    Args:
      tensor_list: Iterable of objects with a ``name`` attribute
        (``Tensor``s or ``Operation``s).

    Returns:
      A list of ``bytes`` names in input order.
    """
    to_bytes = compat.as_bytes
    return [to_bytes(item.name) for item in tensor_list]
def reset(target, containers=None, config=None):
    """Resets resource containers on `target` and closes its connected sessions.

    A resource container is distributed across all workers in the same
    cluster as `target`. Resetting a container on `target` clears the
    resources associated with it; in particular, every Variable in the
    container becomes undefined and loses its value and shape.

    NOTE:
    (i) reset() is currently only implemented for distributed sessions.
    (ii) Any sessions on the master named by `target` will be closed.

    If no resource containers are provided, all containers are reset.

    Args:
      target: The execution engine to connect to.
      containers: A list of resource container name strings, or `None` to
        reset all containers.
      config: (Optional.) Protocol buffer with configuration options.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        resetting containers.
    """
    encoded_target = target if target is None else compat.as_bytes(target)
    if containers is None:
        container_names = []
    else:
        container_names = [compat.as_bytes(name) for name in containers]
    tf_session.TF_Reset(encoded_target, container_names, config)
def testBasic(self):
    """Loads the half_plus_two session bundle and checks assets and regression.

    Verifies the asset filename tensors, that exactly one Signatures proto is
    stored, and that the default regression signature computes
    y = 0.5 * x + 2 for x in [0, 1, 2, 3]. The session is always closed.
    """
    base_path = tf.test.test_src_dir_path(
        "contrib/session_bundle/example/half_plus_two/00000123")
    tf.reset_default_graph()
    sess, meta_graph_def = session_bundle.load_session_bundle_from_path(
        base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2}))
    self.assertTrue(sess)
    try:
        asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
        with sess.as_default():
            path1, path2 = sess.run(["filename1:0", "filename2:0"])
            self.assertEqual(
                compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
            self.assertEqual(
                compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)

            collection_def = meta_graph_def.collection_def
            signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
            # assertEqual: the assertEquals alias is deprecated (removed in 3.12).
            self.assertEqual(len(signatures_any), 1)

            signatures = manifest_pb2.Signatures()
            signatures_any[0].Unpack(signatures)
            default_signature = signatures.default_signature
            input_name = default_signature.regression_signature.input.tensor_name
            output_name = default_signature.regression_signature.output.tensor_name
            y = sess.run([output_name], {input_name: np.array([[0], [1], [2], [3]])})
            # The operation is y = 0.5 * x + 2
            self.assertEqual(y[0][0], 2)
            self.assertEqual(y[0][1], 2.5)
            self.assertEqual(y[0][2], 3)
            self.assertEqual(y[0][3], 3.5)
    finally:
        # as_default() does not close the session; release it explicitly.
        sess.close()
def gfile_copy_callback(files_to_copy, export_dir_path):
    """Callback to copy files using `gfile.Copy` to an export directory.

    This method is used as the default `assets_callback` in `Exporter.init` to
    copy assets from the `assets_collection`. It can also be invoked directly
    to copy additional supplementary files into the export directory (in which
    case it is not a callback).

    Args:
      files_to_copy: A dictionary that maps original file paths to desired
        basename in the export directory.
      export_dir_path: Directory to copy the files to. It is created if it
        does not exist.
    """
    # Fixed typo ("assest") so the message matches the other copy helper.
    logging.info("Write assets into: %s using gfile_copy.", export_dir_path)
    gfile.MakeDirs(export_dir_path)
    for source_filepath, basename in files_to_copy.items():
        new_path = os.path.join(
            compat.as_bytes(export_dir_path), compat.as_bytes(basename))
        logging.info("Copying asset %s to path %s.", source_filepath, new_path)
        if gfile.Exists(new_path):
            # Guard against being restarted while copying assets, and the file
            # existing and being in an unknown state.
            # TODO(b/28676216): Do some file checks before deleting.
            logging.info("Removing file %s.", new_path)
            gfile.Remove(new_path)
        gfile.Copy(source_filepath, new_path)
def gfile_copy_callback(files_to_copy, export_dir_path):
    """Copies files into an export directory with `gfile.Copy`.

    Serves as the default `assets_callback` of `Exporter.init`, copying the
    assets listed in the `assets_collection`. It may also be called directly
    to place extra supporting files in the export directory, in which case it
    is not acting as a callback.

    Args:
      files_to_copy: Dictionary mapping each source file path to the basename
        it should have inside the export directory.
      export_dir_path: Destination directory; created if missing.
    """
    logging.info("Write assets into: %s using gfile_copy.", export_dir_path)
    gfile.MakeDirs(export_dir_path)
    for src, target_name in files_to_copy.items():
        destination = os.path.join(
            compat.as_bytes(export_dir_path), compat.as_bytes(target_name))
        logging.info("Copying asset %s to path %s.", src, destination)
        if not gfile.Exists(destination):
            gfile.Copy(src, destination)
            continue
        # A previous (possibly interrupted) copy may have left the file in an
        # unknown state, so delete it before copying again.
        # TODO(b/28676216): Do some file checks before deleting.
        logging.info("Removing file %s.", destination)
        gfile.Remove(destination)
        gfile.Copy(src, destination)
def run_benchmark(sess, init_op, add_op):
    """Times repeated runs of ``add_op`` and logs per-step MB/s rates.

    Runs ``FLAGS.iters`` timed steps; each step executes ``add_op``
    ``FLAGS.iters_per_step`` times. The per-step throughput, computed as
    ``iters_per_step * FLAGS.data_mb / elapsed_seconds`` (presumably
    ``data_mb`` is the size processed per run), is written as a ``rate``
    event to ``<FLAGS.logdir_prefix>/<FLAGS.name>/events``. Returns None.

    Args:
      sess: Session used to run the ops.
      init_op: Initialization op run once before timing starts.
      add_op: Tensor whose ``.op`` is executed in the timed loop.
    """
    logdir = FLAGS.logdir_prefix + '/' + FLAGS.name
    # Create the directory directly instead of via `os.system('mkdir -p ...')`
    # so the flag-provided path is never interpreted by a shell.
    try:
        os.makedirs(logdir)
    except OSError:
        if not os.path.isdir(logdir):
            raise
    # TODO: make events follow same format as eager writer
    writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(logdir + '/events'))
    try:
        training_util.get_or_create_global_step()
        sess.run(init_op)
        for step in range(FLAGS.iters):
            start_time = time.time()
            for _ in range(FLAGS.iters_per_step):
                sess.run(add_op.op)
            elapsed_time = time.time() - start_time
            # The timed window covers iters_per_step runs, not FLAGS.iters.
            rate = float(FLAGS.iters_per_step) * FLAGS.data_mb / elapsed_time
            event = make_event('rate', rate, step)
            writer.WriteEvent(event)
            writer.Flush()
    finally:
        writer.Close()
# add event
def __init__(self, dir, prefix):
    self.dir = dir
    # Numbering begins at 1 because the EventsWriter already generates an
    # object with step = 0 on its own.
    self.step = 1
    events_path = os.path.join(dir, prefix)
    self.evwriter = pywrap_tensorflow.EventsWriter(compat.as_bytes(events_path))
def __init__(self, dir, prefix):
    self.dir = dir
    # Numbering begins at 1 because the EventsWriter already generates an
    # object with step = 0 on its own.
    self.step = 1
    events_path = os.path.join(dir, prefix)
    self.evwriter = pywrap_tensorflow.EventsWriter(compat.as_bytes(events_path))
def __init__(self, dir):
    # Make sure the log directory exists; an existing one is accepted.
    os.makedirs(dir, exist_ok=True)
    self.dir = dir
    self.step = 1
    # TensorFlow is imported here rather than at module level, presumably so
    # it is only required when this writer is actually constructed.
    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.core.util import event_pb2
    from tensorflow.python.util import compat
    self.tf = tf
    self.event_pb2 = event_pb2
    self.pywrap_tensorflow = pywrap_tensorflow
    events_path = osp.join(osp.abspath(dir), 'events')
    self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(events_path))
def __init__(self, dir):
    # Make sure the log directory exists; an existing one is accepted.
    os.makedirs(dir, exist_ok=True)
    self.dir = dir
    self.step = 1
    # TensorFlow is imported here rather than at module level, presumably so
    # it is only required when this writer is actually constructed.
    import tensorflow as tf
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.core.util import event_pb2
    from tensorflow.python.util import compat
    self.tf = tf
    self.event_pb2 = event_pb2
    self.pywrap_tensorflow = pywrap_tensorflow
    events_path = osp.join(osp.abspath(dir), 'events')
    self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(events_path))
experiment_test.py 文件源码
项目:DeepLearning_VirtualReality_BigData_Project
作者: rashmitripathi
项目源码
文件源码
阅读 26
收藏 0
点赞 0
评论 0
def export_savedmodel(self, export_dir_base, serving_input_fn, **kwargs):
    """Test double: logs its arguments, counts the call, returns a fake path.

    Returns:
      ``<export_dir_base>/bogus_timestamp`` as bytes; nothing is written.
    """
    message = 'export_savedmodel called with args: %s, %s, %s' % (
        export_dir_base, serving_input_fn, kwargs)
    tf_logging.info(message)
    self.export_count += 1
    base = compat.as_bytes(export_dir_base)
    return os.path.join(base, compat.as_bytes('bogus_timestamp'))
def test_stale_asset_collections_are_cleaned(self):
    """Re-exporting a loaded transform must keep its asset collections valid.

    Builds a SavedModel whose lookup table reads a vocabulary asset, then
    three times loads it with ``apply_saved_transform`` and saves it again.
    Each time it checks that the graph has exactly one ASSET_FILEPATHS entry,
    no ``tf.saved_model.constants.ASSETS_KEY`` entries, and that every asset
    path names a tensor present in the graph.
    """
    vocabulary_file = os.path.join(
        compat.as_bytes(test.get_temp_dir()), compat.as_bytes('asset'))
    file_io.write_string_to_file(vocabulary_file, 'foo bar baz')
    export_path = os.path.join(tempfile.mkdtemp(), 'export')

    # Create a SavedModel including assets. `with tf.Session()` (unlike
    # `tf.Session().as_default()`) also closes the session on exit.
    with tf.Graph().as_default():
        with tf.Session() as session:
            input_string = tf.placeholder(tf.string)
            # Map string through a table loaded from an asset file.
            table = lookup.index_table_from_file(
                vocabulary_file, num_oov_buckets=12, default_value=12)
            output = table.lookup(input_string)
            inputs = {'input': input_string}
            outputs = {'output': output}
            saved_transform_io.write_saved_transform_from_session(
                session, inputs, outputs, export_path)

    # Load it and save it again repeatedly, verifying that the asset
    # collections remain valid.
    for _ in range(3):
        with tf.Graph().as_default() as g:
            with tf.Session() as session:
                input_string = tf.constant('dog')
                inputs = {'input': input_string}
                outputs = saved_transform_io.apply_saved_transform(
                    export_path, inputs)

                self.assertEqual(
                    1, len(g.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))
                self.assertEqual(
                    0, len(g.get_collection(tf.saved_model.constants.ASSETS_KEY)))

                # Check that every ASSET_FILEPATHS refers to a Tensor in the
                # graph. If not, get_tensor_by_name() raises KeyError.
                for asset_path in g.get_collection(
                        ops.GraphKeys.ASSET_FILEPATHS):
                    g.get_tensor_by_name(asset_path.name)

                export_path = os.path.join(tempfile.mkdtemp(), 'export')
                saved_transform_io.write_saved_transform_from_session(
                    session, inputs, outputs, export_path)
def _do_run(self, handle, target_list, fetch_list, feed_dict,
            options, run_metadata):
    """Runs a step based on the given fetches and feeds.

    Converts feeds, fetches and targets into the representation expected by
    the active backend (C-API TF_Output/op handles when
    ``self._created_with_new_api`` is set, byte-string names otherwise) and
    dispatches either a full run or a partial-run continuation through
    ``self._do_call``.

    Args:
      handle: a handle for partial_run. None if this is just a call to run().
      target_list: A list of operations to be run, but not fetched.
      fetch_list: A list of tensors to be fetched.
      feed_dict: A dictionary that maps tensors to numpy ndarrays.
      options: A (pointer to a) [`RunOptions`] protocol buffer, or None
      run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None

    Returns:
      A list of numpy ndarrays, corresponding to the elements of
      `fetch_list`. If the ith element of `fetch_list` contains the
      name of an operation, the first Tensor output of that operation
      will be returned for that element.

    Raises:
      tf.errors.OpError: Or one of its subclasses on error.
      RuntimeError: if `handle` is given and `target_list` is non-empty.
    """
    if self._created_with_new_api:
        # C API: feeds/fetches become TF_Output structs, targets raw C ops.
        # pylint: disable=protected-access
        feeds = dict((t._as_tf_output(), v) for t, v in feed_dict.items())
        fetches = [t._as_tf_output() for t in fetch_list]
        targets = [op._c_op for op in target_list]
        # pylint: enable=protected-access
    else:
        # Legacy API: everything is identified by its name as bytes.
        feeds = dict((compat.as_bytes(t.name), v) for t, v in feed_dict.items())
        fetches = _name_list(fetch_list)
        targets = _name_list(target_list)

    # Full-run callback; its parameters shadow the enclosing names and receive
    # the converted feeds/fetches/targets from self._do_call.
    def _run_fn(session, feed_dict, fetch_list, target_list, options,
                run_metadata):
        # Ensure any changes to the graph are reflected in the runtime.
        self._extend_graph()
        with errors.raise_exception_on_not_ok_status() as status:
            if self._created_with_new_api:
                return tf_session.TF_SessionRun_wrapper(
                    session, options, feed_dict, fetch_list, target_list,
                    run_metadata, status)
            else:
                return tf_session.TF_Run(session, options,
                                         feed_dict, fetch_list, target_list,
                                         status, run_metadata)

    # Partial-run callback (legacy API only).
    def _prun_fn(session, handle, feed_dict, fetch_list):
        assert not self._created_with_new_api, ('Partial runs don\'t work with '
                                                'C API')
        # `target_list` here is the enclosing _do_run argument captured by
        # closure; partial runs may not specify targets.
        if target_list:
            raise RuntimeError('partial_run() requires empty target_list.')
        with errors.raise_exception_on_not_ok_status() as status:
            return tf_session.TF_PRun(session, handle, feed_dict, fetch_list,
                                      status)

    # No handle: ordinary run(). Otherwise continue the partial run `handle`.
    if handle is None:
        return self._do_call(_run_fn, self._session, feeds, fetches, targets,
                             options, run_metadata)
    else:
        return self._do_call(_prun_fn, self._session, handle, feeds, fetches)
def createImageLists(imageDir, testingPercentage, validationPercventage):
    """Builds per-label lists of training, testing and validation images.

    Each sub-directory of ``imageDir`` (below the root) is one label. JPEG
    files in it are assigned to a split by hashing the file path with any
    ``_nohash_...`` suffix removed. The assignment is therefore stable across
    runs, and files that differ only in that suffix land in the same split.

    Args:
      imageDir: Root directory containing one sub-directory per label.
      testingPercentage: Percentage (0-100) of images for the test split.
      validationPercventage: Percentage (0-100) of images for the validation
        split. (Name kept as-is for keyword-argument callers.)

    Returns:
      A dict mapping each label name (lower-cased, runs of characters other
      than a-z/0-9 replaced by a space) to a dict with keys ``'dir'``
      (sub-directory base name), ``'training'``, ``'testing'`` and
      ``'validation'`` (lists of file base names). Returns None if
      ``imageDir`` does not exist.
    """
    if not gfile.Exists(imageDir):
        print("Image dir '" + imageDir + "' not found.")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    # The first Walk entry is the root directory itself; skip it.
    for subDir in subDirs[1:]:
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        fileList = []
        seen = set()
        for extension in ['jpg', 'jpeg', 'JPG', 'JPEG']:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            for path in gfile.Glob(fileGlob):
                # On case-insensitive file systems '*.jpg' and '*.JPG' match
                # the same files; keep each path only once.
                if path not in seen:
                    seen.add(path)
                    fileList.append(path)
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages = []
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            # Ignore everything after '_nohash_' so such variants share a split.
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHashed = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            # Map the hash onto a stable value in [0, 100].
            percentHash = ((int(hashNameHashed, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
def createImageLists(imageDir, testingPercentage, validationPercventage):
    """Builds per-label lists of training, testing and validation images.

    Each sub-directory of ``imageDir`` (below the root) is one label. JPEG
    files in it are assigned to a split by hashing the file path with any
    ``_nohash_...`` suffix removed. The assignment is therefore stable across
    runs, and files that differ only in that suffix land in the same split.

    Args:
      imageDir: Root directory containing one sub-directory per label.
      testingPercentage: Percentage (0-100) of images for the test split.
      validationPercventage: Percentage (0-100) of images for the validation
        split. (Name kept as-is for keyword-argument callers.)

    Returns:
      A dict mapping each label name (lower-cased, runs of characters other
      than a-z/0-9 replaced by a space) to a dict with keys ``'dir'``
      (sub-directory base name), ``'training'``, ``'testing'`` and
      ``'validation'`` (lists of file base names). Returns None if
      ``imageDir`` does not exist.
    """
    if not gfile.Exists(imageDir):
        print("Image dir '" + imageDir + "' not found.")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    # The first Walk entry is the root directory itself; skip it.
    for subDir in subDirs[1:]:
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        fileList = []
        seen = set()
        for extension in ['jpg', 'jpeg', 'JPG', 'JPEG']:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            for path in gfile.Glob(fileGlob):
                # On case-insensitive file systems '*.jpg' and '*.JPG' match
                # the same files; keep each path only once.
                if path not in seen:
                    seen.add(path)
                    fileList.append(path)
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages = []
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            # Ignore everything after '_nohash_' so such variants share a split.
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHashed = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            # Map the hash onto a stable value in [0, 100].
            percentHash = ((int(hashNameHashed, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
def createImageLists(imageDir, testingPercentage, validationPercventage):
    """Builds per-label lists of training, testing and validation images.

    Each sub-directory of ``imageDir`` (below the root) is one label. JPEG
    files in it are assigned to a split by hashing the file path with any
    ``_nohash_...`` suffix removed. The assignment is therefore stable across
    runs, and files that differ only in that suffix land in the same split.

    Args:
      imageDir: Root directory containing one sub-directory per label.
      testingPercentage: Percentage (0-100) of images for the test split.
      validationPercventage: Percentage (0-100) of images for the validation
        split. (Name kept as-is for keyword-argument callers.)

    Returns:
      A dict mapping each label name (lower-cased, runs of characters other
      than a-z/0-9 replaced by a space) to a dict with keys ``'dir'``
      (sub-directory base name), ``'training'``, ``'testing'`` and
      ``'validation'`` (lists of file base names). Returns None if
      ``imageDir`` does not exist.
    """
    if not gfile.Exists(imageDir):
        print("Image dir '" + imageDir + "' not found.")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    # The first Walk entry is the root directory itself; skip it.
    for subDir in subDirs[1:]:
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        fileList = []
        seen = set()
        for extension in ['jpg', 'jpeg', 'JPG', 'JPEG']:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            for path in gfile.Glob(fileGlob):
                # On case-insensitive file systems '*.jpg' and '*.JPG' match
                # the same files; keep each path only once.
                if path not in seen:
                    seen.add(path)
                    fileList.append(path)
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages = []
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            # Ignore everything after '_nohash_' so such variants share a split.
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHashed = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            # Map the hash onto a stable value in [0, 100].
            percentHash = ((int(hashNameHashed, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
def createImageLists(imageDir, testingPercentage, validationPercventage):
    """Builds per-label lists of training, testing and validation images.

    Each sub-directory of ``imageDir`` (below the root) is one label. JPEG
    files in it are assigned to a split by hashing the file path with any
    ``_nohash_...`` suffix removed. The assignment is therefore stable across
    runs, and files that differ only in that suffix land in the same split.

    Args:
      imageDir: Root directory containing one sub-directory per label.
      testingPercentage: Percentage (0-100) of images for the test split.
      validationPercventage: Percentage (0-100) of images for the validation
        split. (Name kept as-is for keyword-argument callers.)

    Returns:
      A dict mapping each label name (lower-cased, runs of characters other
      than a-z/0-9 replaced by a space) to a dict with keys ``'dir'``
      (sub-directory base name), ``'training'``, ``'testing'`` and
      ``'validation'`` (lists of file base names). Returns None if
      ``imageDir`` does not exist.
    """
    if not gfile.Exists(imageDir):
        print("Image dir '" + imageDir + "' not found.")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    # The first Walk entry is the root directory itself; skip it.
    for subDir in subDirs[1:]:
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        fileList = []
        seen = set()
        for extension in ['jpg', 'jpeg', 'JPG', 'JPEG']:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            for path in gfile.Glob(fileGlob):
                # On case-insensitive file systems '*.jpg' and '*.JPG' match
                # the same files; keep each path only once.
                if path not in seen:
                    seen.add(path)
                    fileList.append(path)
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages = []
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            # Ignore everything after '_nohash_' so such variants share a split.
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHashed = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            # Map the hash onto a stable value in [0, 100].
            percentHash = ((int(hashNameHashed, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
def createImageLists(imageDir, testingPercentage, validationPercventage):
    """Builds per-label lists of training, testing and validation images.

    Each sub-directory of ``imageDir`` (below the root) is one label. JPEG
    files in it are assigned to a split by hashing the file path with any
    ``_nohash_...`` suffix removed. The assignment is therefore stable across
    runs, and files that differ only in that suffix land in the same split.

    Args:
      imageDir: Root directory containing one sub-directory per label.
      testingPercentage: Percentage (0-100) of images for the test split.
      validationPercventage: Percentage (0-100) of images for the validation
        split. (Name kept as-is for keyword-argument callers.)

    Returns:
      A dict mapping each label name (lower-cased, runs of characters other
      than a-z/0-9 replaced by a space) to a dict with keys ``'dir'``
      (sub-directory base name), ``'training'``, ``'testing'`` and
      ``'validation'`` (lists of file base names). Returns None if
      ``imageDir`` does not exist.
    """
    if not gfile.Exists(imageDir):
        print("Image dir '" + imageDir + "' not found.")
        return None
    result = {}
    subDirs = [x[0] for x in gfile.Walk(imageDir)]
    # The first Walk entry is the root directory itself; skip it.
    for subDir in subDirs[1:]:
        dirName = os.path.basename(subDir)
        if dirName == imageDir:
            continue
        print("Looking for images in '" + dirName + "'")
        fileList = []
        seen = set()
        for extension in ['jpg', 'jpeg', 'JPG', 'JPEG']:
            fileGlob = os.path.join(imageDir, dirName, '*.' + extension)
            for path in gfile.Glob(fileGlob):
                # On case-insensitive file systems '*.jpg' and '*.JPG' match
                # the same files; keep each path only once.
                if path not in seen:
                    seen.add(path)
                    fileList.append(path)
        if not fileList:
            print('No file found')
            continue
        labelName = re.sub(r'[^a-z0-9]+', ' ', dirName.lower())
        trainingImages = []
        testingImages = []
        validationImages = []
        for fileName in fileList:
            baseName = os.path.basename(fileName)
            # Ignore everything after '_nohash_' so such variants share a split.
            hashName = re.sub(r'_nohash_.*$', '', fileName)
            hashNameHashed = hashlib.sha1(compat.as_bytes(hashName)).hexdigest()
            # Map the hash onto a stable value in [0, 100].
            percentHash = ((int(hashNameHashed, 16) %
                            (MAX_NUM_IMAGES_PER_CLASS + 1)) *
                           (100.0 / MAX_NUM_IMAGES_PER_CLASS))
            if percentHash < validationPercventage:
                validationImages.append(baseName)
            elif percentHash < (testingPercentage + validationPercventage):
                testingImages.append(baseName)
            else:
                trainingImages.append(baseName)
        result[labelName] = {
            'dir': dirName,
            'training': trainingImages,
            'testing': testingImages,
            'validation': validationImages,
        }
    return result
def _add_collection_def(meta_graph_def, key):
    """Adds a collection to MetaGraphDef protocol buffer.

    Serializes the items of graph collection ``key`` (from
    ``ops.get_collection``) into ``meta_graph_def.collection_def[key]``.
    If a to_proto function is registered for ``key``, each item is stored as
    a serialized proto in ``bytes_list``. Otherwise the list kind is chosen
    from the first item via ``_get_kind_name``. Non-string keys and empty
    collections are skipped. On any serialization error, a warning is logged
    and the partially written entry is removed.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
    """
    if not isinstance(key, six.string_types + (bytes,)):
        logging.warning("Only collections with string type keys will be "
                        "serialized. This key has %s", type(key))
        return
    collection_list = ops.get_collection(key)
    if not collection_list:
        return
    try:
        col_def = meta_graph_def.collection_def[key]
        to_proto = ops.get_to_proto_function(key)
        proto_type = ops.get_collection_proto_type(key)
        if to_proto:
            kind = "bytes_list"
            for x in collection_list:
                # Additional type check to make sure the returned proto is
                # indeed what we expect. An explicit raise (unlike `assert`)
                # still fires under `python -O`; it is handled below.
                proto = to_proto(x)
                if not isinstance(proto, proto_type):
                    raise TypeError("Expected %s from to_proto, got %s" %
                                    (proto_type, type(proto)))
                getattr(col_def, kind).value.append(proto.SerializeToString())
        else:
            kind = _get_kind_name(collection_list[0])
            if kind == "node_list":
                getattr(col_def, kind).value.extend(
                    [x.name for x in collection_list])
            elif kind == "bytes_list":
                # NOTE(opensource): This force conversion is to work around the
                # fact that Python3 distinguishes between bytes and strings.
                getattr(col_def, kind).value.extend(
                    [compat.as_bytes(x) for x in collection_list])
            else:
                getattr(col_def, kind).value.extend(list(collection_list))
    except Exception as e:  # pylint: disable=broad-except
        logging.warning("Error encountered when serializing %s.\n"
                        "Type is unsupported, or the types of the items don't "
                        "match field type in CollectionDef.\n%s", key, str(e))
        # Accessing collection_def[key] above created the entry; drop it so
        # no partially filled collection is left behind.
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
        return