def make_testable(train_model_path):
    """Load a training prototxt and adapt it so it can be run for testing.

    Arguments:
        train_model_path: Path to the training net in protobuf text format.

    Returns:
        The parsed ``caffe_pb2.NetParameter``, modified in place: every
        "BN" layer with exactly one top gains "<top>-mean" and "<top>-var"
        tops, and a second layer named "data" that carries an ``include``
        rule is removed together with the first ``include`` rule of the
        first layer.
    """
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # A net with fewer than two layers has no second data layer to remove;
    # return it unchanged instead of raising IndexError on layer[1].
    if len(train_net.layer) < 2:
        return train_net

    # remove the test data layer if present
    # NOTE(review): the source lost its indentation; the include cleanup is
    # assumed to be nested under the data-layer removal -- confirm upstream.
    if train_net.layer[1].name == "data" and train_net.layer[1].include:
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
python类Merge()的实例源码
def import_protobuf(self, pb_file, verbose=False):
    """
    Read a graph_def from a protobuf file and import it.

    Arguments:
        pb_file: Protobuf file path. When its guessed MIME type is
                 'text/plain' the file is parsed as protobuf text,
                 otherwise as a serialized binary message.
        verbose: Forwarded unchanged to ``import_graph_def``.
    """
    graph_def = tf.GraphDef()
    mime_type = mimetypes.guess_type(pb_file)[0]

    if mime_type != 'text/plain':
        # binary serialization
        with open(pb_file, 'rb') as binary_file:
            graph_def.ParseFromString(binary_file.read())
    else:
        # human-readable text format
        with open(pb_file, 'r') as text_file:
            text_format.Merge(text_file.read(), graph_def)

    self.import_graph_def(graph_def, verbose=verbose)
def _assert_encode_decode(self, coder, expected_proto_text, expected_decoded):
    """Check a decode -> encode -> decode round trip through `coder`.

    Arguments:
        coder: Object exposing ``decode`` and ``encode``.
        expected_proto_text: Text-format ``tf.train.Example`` used as the
            serialized starting point.
        expected_decoded: Value both decode steps must produce.
    """
    reference_example = tf.train.Example()
    text_format.Merge(expected_proto_text, reference_example)
    serialized = reference_example.SerializeToString()

    # Decoding the serialized proto must give the expected value.
    first_decode = coder.decode(serialized)
    np.testing.assert_equal(expected_decoded, first_decode)

    # Encoding that value must reproduce the original proto.
    reencoded = coder.encode(first_decode)
    roundtrip_example = tf.train.Example()
    roundtrip_example.ParseFromString(reencoded)
    self.assertEqual(reference_example, roundtrip_example)

    # Decoding the re-encoded bytes must give the expected value again.
    second_decode = coder.decode(reencoded)
    np.testing.assert_equal(expected_decoded, second_decode)
def _assert_decode_encode(self, coder, expected_proto_text, expected_decoded):
    """Check an encode -> decode -> encode round trip through `coder`.

    Arguments:
        coder: Object exposing ``encode`` and ``decode``.
        expected_proto_text: Text-format ``tf.train.Example`` both encode
            steps must reproduce.
        expected_decoded: Decoded value used as the starting point.
    """
    reference_example = tf.train.Example()
    text_format.Merge(expected_proto_text, reference_example)

    # Encoding the expected value must produce the expected proto.
    first_encoding = coder.encode(expected_decoded)
    first_parsed = tf.train.Example()
    first_parsed.ParseFromString(first_encoding)
    self.assertEqual(reference_example, first_parsed)

    # Decoding those bytes must give back the original input.
    roundtrip_value = coder.decode(first_encoding)
    np.testing.assert_equal(expected_decoded, roundtrip_value)

    # Encoding the decoded value must produce the expected proto again.
    second_encoding = coder.encode(roundtrip_value)
    second_parsed = tf.train.Example()
    second_parsed.ParseFromString(second_encoding)
    np.testing.assert_equal(reference_example, second_parsed)
def test_example_proto_coder_error(self):
    """Encoding and decoding mismatched data both raise ValueError."""
    feature_spec = {
        '2d_vector_feature': tf.FixedLenFeature(shape=[2, 2], dtype=tf.int64),
    }
    coder = example_proto_coder.ExampleProtoCoder(
        dataset_schema.from_feature_spec(feature_spec))

    # Three values, while the schema declares a 2x2 feature.
    bad_decoded_value = {'2d_vector_feature': [1, 2, 3]}

    bad_proto_text = """
features {
feature { key: "1d_vector_feature"
value { int64_list { value: [ 1, 2, 3 ] } } }
}
"""
    bad_example = tf.train.Example()
    text_format.Merge(bad_proto_text, bad_example)

    expected_error = 'got wrong number of values'

    # Ensure that we raise an exception for trying to encode invalid data.
    with self.assertRaisesRegexp(ValueError, expected_error):
        coder.encode(bad_decoded_value)

    # Ensure that we raise an exception for trying to parse invalid data.
    with self.assertRaisesRegexp(ValueError, expected_error):
        coder.decode(bad_example.SerializeToString())
text_format_test.py 文件源码
项目:Vector-Tiles-Reader-QGIS-Plugin
作者: geometalab
项目源码
文件源码
阅读 19
收藏 0
点赞 0
评论 0
def testMergeExpandedAnyRepeated(self):
    """Merge two expanded Any entries and unpack each one in order."""
    any_message = any_test_pb2.TestAny()
    proto_text = ('repeated_any_value {\n'
                  ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
                  ' data: "string0"\n'
                  ' }\n'
                  '}\n'
                  'repeated_any_value {\n'
                  ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
                  ' data: "string1"\n'
                  ' }\n'
                  '}\n')
    text_format.Merge(proto_text, any_message,
                      descriptor_pool=descriptor_pool.Default())

    # The same target message is reused for both Unpack calls.
    unpacked = unittest_pb2.OneString()
    for position, expected_data in enumerate(('string0', 'string1')):
        any_message.repeated_any_value[position].Unpack(unpacked)
        self.assertEqual(expected_data, unpacked.data)
reflection_test.py 文件源码
项目:Vector-Tiles-Reader-QGIS-Plugin
作者: geometalab
项目源码
文件源码
阅读 27
收藏 0
点赞 0
评论 0
def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    file_proto = descriptor_pb2.FileDescriptorProto()
    file_proto.ParseFromString(self._GetSerializedFileDescriptor('C'))

    # Build a message class from the first message type in the file.
    first_message_type = file_proto.message_type[0]
    generated_class = reflection.MakeClass(
        descriptor.MakeDescriptor(first_message_type))
    parsed = generated_class()

    nested_text = 'bar { baz { deep: 4 }}'
    text_format.Merge(nested_text, parsed)
    self.assertEqual(parsed.bar.baz.deep, 4)
def make_testable(train_model_path):
    """Parse a training prototxt and rewrite it for test-time use.

    Arguments:
        train_model_path: Path to the training net in protobuf text format.

    Returns:
        The ``caffe_pb2.NetParameter`` read from the file. Each "BN" layer
        that has a single top gets "<top>-mean" and "<top>-var" tops; a
        second layer named "data" with an ``include`` rule is dropped, and
        the first layer then loses its first ``include`` rule.
    """
    # load the train net prototxt as a protobuf message
    train_net = caffe_pb2.NetParameter()
    with open(train_model_path) as prototxt:
        text_format.Merge(prototxt.read(), train_net)
    layers = train_net.layer

    # add the mean, var top blobs to all BN layers
    for layer in layers:
        if layer.type != "BN" or len(layer.top) != 1:
            continue
        top_name = layer.top[0]
        layer.top.extend([top_name + "-mean", top_name + "-var"])

    # remove the test data layer if present; the length check keeps nets
    # with fewer than two layers from raising IndexError on layers[1].
    # NOTE(review): indentation was lost in the source; the include cleanup
    # is assumed to be nested under the removal -- confirm upstream.
    if len(layers) > 1 and layers[1].name == "data" and layers[1].include:
        layers.remove(layers[1])
        first_includes = layers[0].include
        if first_includes:
            # remove the 'include {phase: TRAIN}' layer param
            first_includes.remove(first_includes[0])
    return train_net
def make_testable(train_model_path):
    """Return the training net at `train_model_path`, patched for testing.

    Arguments:
        train_model_path: Path to the training net in protobuf text format.

    Returns:
        A ``caffe_pb2.NetParameter``. "BN" layers with one top receive
        extra "<top>-mean" and "<top>-var" tops. If the second layer is
        named "data" and has an ``include`` rule, it is removed and the
        first ``include`` rule of the first layer is removed too.
    """
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    single_top_bn_layers = [
        layer for layer in train_net.layer
        if layer.type == "BN" and len(layer.top) == 1
    ]
    for bn_layer in single_top_bn_layers:
        for suffix in ("-mean", "-var"):
            bn_layer.top.append(bn_layer.top[0] + suffix)

    # Only look at layer[1] when it exists, so nets with fewer than two
    # layers are returned as-is rather than raising IndexError.
    has_second_layer = len(train_net.layer) >= 2

    # remove the test data layer if present
    # NOTE(review): the nesting of the include cleanup is assumed, because
    # the source lost its indentation -- confirm upstream.
    if (has_second_layer and train_net.layer[1].name == "data"
            and train_net.layer[1].include):
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    serialized_descriptor = self._GetSerializedFileDescriptor('C')
    file_descriptor = descriptor_pb2.FileDescriptorProto()
    file_descriptor.ParseFromString(serialized_descriptor)

    # The class is generated from the first message type of the file.
    msg_descriptor = descriptor.MakeDescriptor(file_descriptor.message_type[0])
    msg = reflection.MakeClass(msg_descriptor)()

    msg_str = ''.join([
        'bar {',
        ' baz {',
        ' deep: 4',
        ' }',
        '}',
    ])
    text_format.Merge(msg_str, msg)
    self.assertEqual(msg.bar.baz.deep, 4)
def testRoundTripExoticAsOneLine(self):
    """Round-trip extreme values through the one-line text format.

    The message is rendered with ``as_one_line=True`` for both settings of
    ``as_utf8``; each rendering must merge back into an equal message.
    """
    message = unittest_pb2.TestAllTypes()
    message.repeated_int64.append(-9223372036854775808)
    message.repeated_uint64.append(18446744073709551615)
    message.repeated_double.append(123.456)
    message.repeated_double.append(1.23e22)
    message.repeated_double.append(1.23e-18)
    message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"')
    message.repeated_string.append(u'\u00fc\ua71f')

    # Test as_utf8 = False first, then as_utf8 = True.
    for as_utf8 in (False, True):
        wire_text = text_format.MessageToString(
            message, as_one_line=True, as_utf8=as_utf8)
        parsed_message = unittest_pb2.TestAllTypes()
        text_format.Merge(wire_text, parsed_message)
        # assertEqual replaces the deprecated assertEquals alias, which
        # was removed in Python 3.12.
        self.assertEqual(message, parsed_message)
def testMergeMessageSet(self):
    """Merge repeated scalar fields, then message-set extensions."""
    message = unittest_pb2.TestAllTypes()
    text = ('repeated_uint64: 1\n'
            'repeated_uint64: 2\n')
    text_format.Merge(text, message)
    self.assertEqual(1, message.repeated_uint64[0])
    self.assertEqual(2, message.repeated_uint64[1])

    message = unittest_mset_pb2.TestMessageSetContainer()
    text = ('message_set {\n'
            ' [protobuf_unittest.TestMessageSetExtension1] {\n'
            ' i: 23\n'
            ' }\n'
            ' [protobuf_unittest.TestMessageSetExtension2] {\n'
            ' str: \"foo\"\n'
            ' }\n'
            '}\n')
    text_format.Merge(text, message)
    ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
    ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
    # assertEqual throughout: the assertEquals alias is deprecated and was
    # removed in Python 3.12.
    self.assertEqual(23, message.message_set.Extensions[ext1].i)
    self.assertEqual('foo', message.message_set.Extensions[ext2].str)
def testMergeExotic(self):
    """Merge boundary integers, doubles and escaped strings from text."""
    message = unittest_pb2.TestAllTypes()
    text = ('repeated_int64: -9223372036854775808\n'
            'repeated_uint64: 18446744073709551615\n'
            'repeated_double: 123.456\n'
            'repeated_double: 1.23e+22\n'
            'repeated_double: 1.23e-18\n'
            'repeated_string: \n'
            '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n'
            'repeated_string: "foo" \'corge\' "grault"\n'
            'repeated_string: "\\303\\274\\352\\234\\237"\n'
            'repeated_string: "\\xc3\\xbc"\n'
            'repeated_string: "\xc3\xbc"\n')
    text_format.Merge(text, message)

    self.assertEqual(-9223372036854775808, message.repeated_int64[0])
    self.assertEqual(18446744073709551615, message.repeated_uint64[0])

    expected_doubles = (123.456, 1.23e22, 1.23e-18)
    for index, expected in enumerate(expected_doubles):
        self.assertEqual(expected, message.repeated_double[index])

    # Only the first four strings are checked; the fifth is merged but
    # not asserted on.
    expected_strings = (
        '\000\001\a\b\f\n\r\t\v\\\'"',
        'foocorgegrault',
        u'\u00fc\ua71f',
        u'\u00fc',
    )
    for index, expected in enumerate(expected_strings):
        self.assertEqual(expected, message.repeated_string[index])
def testMergeBadEnumValue(self):
    """Unknown enum names and numbers are rejected with a ParseError."""
    error_prefix = (
        '1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ')
    cases = (
        # (text to merge, expected error message)
        ('optional_nested_enum: BARR',
         error_prefix + 'has no value named BARR.'),
        ('optional_nested_enum: 100',
         error_prefix + 'has no value with number 100.'),
    )
    for bad_text, expected_error in cases:
        # Each case merges into a fresh message.
        target = unittest_pb2.TestAllTypes()
        self.assertRaisesWithMessage(
            text_format.ParseError,
            expected_error,
            text_format.Merge, bad_text, target)
def testMergeStringFieldUnescape(self):
    """Check how runs of backslashes before hex escapes are unescaped."""
    message = unittest_pb2.TestAllTypes()
    text = r'''repeated_string: "\xf\x62"
repeated_string: "\\xf\\x62"
repeated_string: "\\\xf\\\x62"
repeated_string: "\\\\xf\\\\x62"
repeated_string: "\\\\\xf\\\\\x62"
repeated_string: "\x5cx20"'''
    text_format.Merge(text, message)

    backslash = '\\'
    # One expected value per repeated_string entry, in input order.
    expected_values = (
        '\x0fb',
        backslash + 'xf' + backslash + 'x62',
        backslash + '\x0f' + backslash + 'b',
        backslash + backslash + 'xf' + backslash + backslash + 'x62',
        backslash + backslash + '\x0f' + backslash + backslash + 'b',
        backslash + 'x20',
    )
    for index, expected in enumerate(expected_values):
        self.assertEqual(expected, message.repeated_string[index])
def testMergeExpandedAny(self):
    """An expanded Any field is read the same way by Merge and by Parse."""
    any_message = any_test_pb2.TestAny()
    any_text = ('any_value {\n'
                ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
                ' data: "string"\n'
                ' }\n'
                '}\n')

    def assert_unpacks_to_string():
        # Unpack into a fresh OneString and check its payload.
        unpacked = unittest_pb2.OneString()
        any_message.any_value.Unpack(unpacked)
        self.assertEqual('string', unpacked.data)

    text_format.Merge(any_text, any_message,
                      descriptor_pool=descriptor_pool.Default())
    assert_unpacks_to_string()

    # Start from an empty message before exercising Parse.
    any_message.Clear()
    text_format.Parse(any_text, any_message,
                      descriptor_pool=descriptor_pool.Default())
    assert_unpacks_to_string()
def testMergeExpandedAnyRepeated(self):
    """Repeated expanded Any values keep their order and payloads."""
    container = any_test_pb2.TestAny()
    text = ('repeated_any_value {\n'
            ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
            ' data: "string0"\n'
            ' }\n'
            '}\n'
            'repeated_any_value {\n'
            ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
            ' data: "string1"\n'
            ' }\n'
            '}\n')
    text_format.Merge(
        text, container, descriptor_pool=descriptor_pool.Default())

    one_string = unittest_pb2.OneString()

    # first entry
    container.repeated_any_value[0].Unpack(one_string)
    self.assertEqual('string0', one_string.data)

    # second entry, unpacked into the same message object
    container.repeated_any_value[1].Unpack(one_string)
    self.assertEqual('string1', one_string.data)
def make_testable(train_model_path):
    """Load a training prototxt and convert it into a testable net.

    Arguments:
        train_model_path: Path to the training net in protobuf text format.

    Returns:
        The parsed ``caffe_pb2.NetParameter`` after the BN tops have been
        added and the test data layer (if any) has been removed.
    """
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    _add_bn_statistic_tops(train_net)
    _remove_test_data_layer(train_net)
    return train_net


def _add_bn_statistic_tops(net):
    """Append "<top>-mean" and "<top>-var" tops to each one-top BN layer."""
    for layer in net.layer:
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")


def _remove_test_data_layer(net):
    """Drop a second "data" layer and the first layer's first include rule."""
    # Nets with fewer than two layers have nothing to remove; checking the
    # length avoids the IndexError that indexing layer[1] would raise.
    if len(net.layer) < 2:
        return
    second = net.layer[1]
    if not (second.name == "data" and second.include):
        return
    net.layer.remove(second)
    # NOTE(review): this cleanup is assumed to depend on the removal above;
    # the source lost its indentation -- confirm upstream.
    if net.layer[0].include:
        # remove the 'include {phase: TRAIN}' layer param
        net.layer[0].include.remove(net.layer[0].include[0])
def add_params(self, params):
    """
    Set or update solver parameters

    Each key of `params` is cleared on ``self.sp`` and then written back
    as protobuf text: ``str`` values are wrapped in double quotes, lists
    emit one line per item, booleans become ``true``/``false`` and any
    other value is rendered with ``str()``.
    """
    lines = []
    for key, val in params.items():
        self.sp.ClearField(key)  # reset field
        if isinstance(val, str):  # string value
            lines.append(key + ': "' + val + '"\n')
        elif type(val) is list:  # repeatable field
            lines.extend(key + ': ' + str(item) + '\n' for item in val)
        elif type(val) is bool:  # boolean type
            lines.append(key + (': true\n' if val else ': false\n'))
        else:  # numerical value
            lines.append(key + ': ' + str(val) + '\n')
    # apply change
    # NOTE(review): the merge is assumed to run once, after the loop; the
    # source lost its indentation -- confirm upstream.
    text_format.Merge(''.join(lines), self.sp)
def __init__(self):
    """Loading DNN model.

    Reads the GoogLeNet deploy prototxt, sets ``force_backward``, writes
    the patched definition to 'tmp.prototxt' in the current directory and
    builds ``self.net`` from it as a ``caffe.Classifier``.
    """
    model_path = '/home/jiri/caffe/models/bvlc_googlenet/'
    net_fn = model_path + 'deploy.prototxt'
    param_fn = model_path + 'bvlc_googlenet.caffemodel'
    # Alternative model kept from the original as a note:
    #   model_path = '/home/jiri/caffe/models/oxford102/'
    #   param_fn = model_path + 'oxford102.caffemodel'

    # Patching model to be able to compute gradients.
    # Note that you can also manually add "force_backward: true" line
    # to "deploy.prototxt".
    model = caffe.io.caffe_pb2.NetParameter()
    with open(net_fn) as net_file:
        text_format.Merge(net_file.read(), model)
    model.force_backward = True
    # Close the file explicitly so the patched definition is flushed to
    # disk before caffe reads it back below.
    with open('tmp.prototxt', 'w') as patched_file:
        patched_file.write(str(model))

    # ImageNet mean, training set dependent
    mean = np.float32([104.0, 116.0, 122.0])
    # the reference model has channels in BGR order instead of RGB
    chann_sw = (2, 1, 0)
    self.net = caffe.Classifier(
        'tmp.prototxt', param_fn, mean=mean, channel_swap=chann_sw)
def make_testable(train_model_path):
    """Load a training prototxt and adapt it so it can be run for testing.

    Arguments:
        train_model_path: Path to the training net in protobuf text format.

    Returns:
        The parsed ``caffe_pb2.NetParameter``, modified in place: "BN"
        layers with one top gain "<top>-mean" and "<top>-var" tops, and a
        second layer named "data" with an ``include`` rule is removed
        together with the first ``include`` rule of the first layer.
    """
    # Debug output kept from the original. Written as a call so the code
    # also parses under Python 3, where the statement form is a SyntaxError.
    print("hello")
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        print(len(layer.top))  # debug output kept from the original
        if layer.type == "BN" and len(layer.top) == 1:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # remove the test data layer if present; the length check returns nets
    # with fewer than two layers unchanged instead of raising IndexError.
    # NOTE(review): the nesting of the include cleanup is assumed, because
    # the source lost its indentation -- confirm upstream.
    if (len(train_net.layer) > 1
            and train_net.layer[1].name == "data"
            and train_net.layer[1].include):
        train_net.layer.remove(train_net.layer[1])
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
def make_testable(train_model_path):
    """Turn the training net stored at `train_model_path` into a test net.

    Arguments:
        train_model_path: Path to the training net in protobuf text format.

    Returns:
        The ``caffe_pb2.NetParameter`` parsed from the file, with extra
        "<top>-mean"/"<top>-var" tops on one-top "BN" layers and without
        the second "data" layer when that layer has an ``include`` rule.
    """
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        tops = layer.top
        if layer.type == "BN" and len(tops) == 1:
            tops.append(tops[0] + "-mean")
            tops.append(tops[0] + "-var")

    # remove the test data layer if present
    layer_count = len(train_net.layer)
    # Indexing layer[1] needs at least two layers; shorter nets are
    # returned unchanged rather than raising IndexError.
    if layer_count >= 2:
        candidate = train_net.layer[1]
        if candidate.name == "data" and candidate.include:
            train_net.layer.remove(candidate)
            # NOTE(review): nesting assumed; the source lost its
            # indentation -- confirm upstream.
            if train_net.layer[0].include:
                # remove the 'include {phase: TRAIN}' layer param
                train_net.layer[0].include.remove(
                    train_net.layer[0].include[0])
    return train_net
def _latest_checkpoints_changed(configs, run_path_pairs):
    """Returns true if the latest checkpoint has changed in any of the runs.

    Arguments:
        configs: Mapping from run name to a config object exposing
            ``model_checkpoint_path``.
        run_path_pairs: Iterable of ``(run_name, assets_dir)`` pairs.
    """
    for run_name, assets_dir in run_path_pairs:
        if run_name in configs:
            run_config = configs[run_name]
        else:
            # Unknown run: start from an empty config and fill it from the
            # projector config file when one exists in the assets dir.
            run_config = ProjectorConfig()
            config_path = os.path.join(assets_dir, PROJECTOR_FILENAME)
            if tf.gfile.Exists(config_path):
                with tf.gfile.GFile(config_path, 'r') as config_file:
                    serialized_config = config_file.read()
                text_format.Merge(serialized_config, run_config)

        # See if you can find a checkpoint file in the logdir.
        latest_ckpt = _find_latest_checkpoint(_assets_dir_to_logdir(assets_dir))
        if latest_ckpt and run_config.model_checkpoint_path != latest_ckpt:
            return True
    return False
def dump_data(logdir):
    """Dumps plugin data to the log directory.

    For every run in ``profile_demo_data.RUNS`` a run directory is created
    under the profile plugin directory. Runs listed in
    ``profile_demo_data.TRACES`` get a serialized ``Trace`` proto written
    to 'trace', and every run gets a copy of the demo 'op_profile.json'.
    An extra 'empty' run holding an 'unsupported' file is written last.

    Arguments:
        logdir: Log directory under which the plugin directory is created.
    """
    plugin_logdir = plugin_asset_util.PluginDirectory(
        logdir, profile_plugin.ProfilePlugin.plugin_name)
    _maybe_create_directory(plugin_logdir)

    for run in profile_demo_data.RUNS:
        run_dir = os.path.join(plugin_logdir, run)
        _maybe_create_directory(run_dir)
        if run in profile_demo_data.TRACES:
            proto = trace_events_pb2.Trace()
            text_format.Merge(profile_demo_data.TRACES[run], proto)
            # SerializeToString() returns bytes, so the file must be
            # opened in binary mode: text mode raises TypeError on
            # Python 3 and rewrites newline bytes on Windows.
            with open(os.path.join(run_dir, 'trace'), 'wb') as f:
                f.write(proto.SerializeToString())
        # NOTE(review): the copy is assumed to run once per run (inside
        # the loop, outside the `if`); the source lost its indentation.
        shutil.copyfile(
            'tensorboard/plugins/profile/profile_demo.op_profile.json',
            os.path.join(run_dir, 'op_profile.json'))

    # Unsupported tool data should not be displayed.
    run_dir = os.path.join(plugin_logdir, 'empty')
    _maybe_create_directory(run_dir)
    with open(os.path.join(run_dir, 'unsupported'), 'w') as f:
        f.write('unsupported data')
def __init__(self):
    """Load the label map and the detection net, and set up preprocessing.

    Populates ``self._labelmap``, ``self._net`` and ``self._transformer``
    from the files named by the module-level ``CAFFE_ROOT``, ``LABEL_MAP``,
    ``PROTO_TXT`` and ``CAFFE_MODEL`` constants.
    """
    # load MS COCO labels
    labelmap_path = os.path.join(CAFFE_ROOT, LABEL_MAP)
    self._labelmap = caffe_pb2.LabelMap()
    # Use a context manager so the handle is closed, and avoid naming the
    # handle `file`, which shadows a builtin name.
    with open(labelmap_path, 'r') as labelmap_file:
        text_format.Merge(str(labelmap_file.read()), self._labelmap)

    model_def = os.path.join(CAFFE_ROOT, PROTO_TXT)
    model_weights = os.path.join(CAFFE_ROOT, CAFFE_MODEL)
    self._net = caffe.Net(model_def, model_weights, caffe.TEST)

    self._transformer = caffe.io.Transformer(
        {'data': self._net.blobs['data'].data.shape})
    self._transformer.set_transpose('data', (2, 0, 1))
    self._transformer.set_mean('data', np.array([104, 117, 123]))
    self._transformer.set_raw_scale('data', 255)
    self._transformer.set_channel_swap('data', (2, 1, 0))

    # set net to batch size of 1
    image_resize = IMAGE_SIZE
    self._net.blobs['data'].reshape(1, 3, image_resize, image_resize)
def make_testable(train_model_path):
    """Read a training net definition and prepare it for test runs.

    Arguments:
        train_model_path: Path to the training net in protobuf text format.

    Returns:
        The parsed ``caffe_pb2.NetParameter``. One-top "BN" layers get
        "<top>-mean" and "<top>-var" tops; a second layer named "data"
        with an ``include`` rule is removed, after which the first
        ``include`` rule of the first layer is removed as well.
    """
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        if layer.type == "BN" and len(layer.top) == 1:
            base_top = layer.top[0]
            layer.top.append(base_top + "-mean")
            layer.top.append(base_top + "-var")

    # remove the test data layer if present
    try:
        second_layer = train_net.layer[1]
    except IndexError:
        # Fewer than two layers: there is no test data layer to remove.
        return train_net
    if second_layer.name == "data" and second_layer.include:
        train_net.layer.remove(second_layer)
        # NOTE(review): nesting assumed; the source lost its indentation
        # -- confirm upstream.
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
def make_testable(train_model_path):
    """Build a test-ready net from the training prototxt at the given path.

    Arguments:
        train_model_path: Path to the training net in protobuf text format.

    Returns:
        The ``caffe_pb2.NetParameter`` parsed from the file, after adding
        "<top>-mean"/"<top>-var" tops to one-top "BN" layers and removing
        a second "data" layer that has an ``include`` rule.
    """
    # load the train net prototxt as a protobuf message
    with open(train_model_path) as f:
        train_str = f.read()
    train_net = caffe_pb2.NetParameter()
    text_format.Merge(train_str, train_net)

    # add the mean, var top blobs to all BN layers
    for layer in train_net.layer:
        is_single_top_bn = layer.type == "BN" and len(layer.top) == 1
        if is_single_top_bn:
            layer.top.append(layer.top[0] + "-mean")
            layer.top.append(layer.top[0] + "-var")

    # remove the test data layer if present
    # Guard the index: nets with fewer than two layers are returned
    # unchanged instead of raising IndexError on layer[1].
    second_is_test_data = (
        len(train_net.layer) > 1
        and train_net.layer[1].name == "data"
        and bool(train_net.layer[1].include))
    if second_is_test_data:
        train_net.layer.remove(train_net.layer[1])
        # NOTE(review): nesting assumed; the source lost its indentation
        # -- confirm upstream.
        if train_net.layer[0].include:
            # remove the 'include {phase: TRAIN}' layer param
            train_net.layer[0].include.remove(train_net.layer[0].include[0])
    return train_net
def testMergeExpandedAnyRepeated(self):
    """Both repeated expanded Any entries unpack to their own string."""
    parsed = any_test_pb2.TestAny()
    text = ('repeated_any_value {\n'
            ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
            ' data: "string0"\n'
            ' }\n'
            '}\n'
            'repeated_any_value {\n'
            ' [type.googleapis.com/protobuf_unittest.OneString] {\n'
            ' data: "string1"\n'
            ' }\n'
            '}\n')
    text_format.Merge(text, parsed, descriptor_pool=descriptor_pool.Default())

    # (index into repeated_any_value, expected unpacked data)
    expectations = ((0, 'string0'), (1, 'string1'))
    holder = unittest_pb2.OneString()
    for index, expected in expectations:
        parsed.repeated_any_value[index].Unpack(holder)
        self.assertEqual(expected, holder.data)
def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    descriptor_proto = descriptor_pb2.FileDescriptorProto()
    descriptor_proto.ParseFromString(self._GetSerializedFileDescriptor('C'))

    # Generate a class for the file's first message type and instantiate it.
    top_level_type = descriptor_proto.message_type[0]
    message_class = reflection.MakeClass(
        descriptor.MakeDescriptor(top_level_type))
    instance = message_class()

    # Three nesting levels: bar -> baz -> deep.
    text = 'bar {' + ' baz {' + ' deep: 4' + ' }' + '}'
    text_format.Merge(text, instance)
    self.assertEqual(instance.bar.baz.deep, 4)
def testParsingNestedClass(self):
    """Test that the generated class can parse a nested message."""
    fd_proto = descriptor_pb2.FileDescriptorProto()
    raw_descriptor = self._GetSerializedFileDescriptor('C')
    fd_proto.ParseFromString(raw_descriptor)

    made_descriptor = descriptor.MakeDescriptor(fd_proto.message_type[0])
    nested_msg = reflection.MakeClass(made_descriptor)()

    # Text-format message with the value nested two levels deep.
    pieces = ('bar {', ' baz {', ' deep: 4', ' }', '}')
    text_format.Merge(''.join(pieces), nested_msg)

    deep_value = nested_msg.bar.baz.deep
    self.assertEqual(deep_value, 4)