def change_model(proto, layers=None, dims=None):
    """Overwrite the ``dimensions`` field of selected layers in a model file.

    The model is loaded with ``util.ReadModel(proto)``. For every name in
    *layers*, the first entry of ``model.layer`` whose ``name`` matches has
    its ``dimensions`` field set. Names with no matching layer are skipped
    silently. The modified model is then written back to *proto* in protobuf
    text format, replacing the original contents.

    Args:
        proto: Path of the model file; it is read, then overwritten in place.
        layers: Iterable of layer names to update. ``None`` selects the
            default image/text/joint layer list below.
        dims: Value to store in each matched layer's ``dimensions`` field.
            When ``None`` (the default), the name ``dimensions`` from the
            enclosing scope is used, exactly as earlier versions of this
            function did.
    """
    model = util.ReadModel(proto)
    if layers is None:
        layers = ['image_hidden1', 'image_hidden2', 'image_hidden3',
                  'text_hidden1', 'text_hidden2', 'text_hidden3',
                  'image_layer', 'text_layer', 'joint_layer',
                  'image_tied_hidden', 'text_tied_hidden',
                  'image_hidden2_recon', 'text_hidden2_recon',
                  'cross_image_hidden2_recon', 'cross_text_hidden2_recon']
    for name in layers:
        # First match only; a name with no matching layer is ignored.
        layer_proto = next(
            (lay for lay in model.layer if lay.name == name), None)
        if layer_proto is None:
            continue
        # Resolve the fallback lazily: the outer ``dimensions`` name is only
        # needed when a layer is actually updated, as in the original code.
        layer_proto.dimensions = dimensions if dims is None else dims
    # Serialize completely before opening the file: mode 'w' truncates it,
    # so a failure while printing would otherwise leave a partial/empty file.
    text = text_format.MessageToString(model)
    with open(proto, 'w') as f:
        f.write(text)
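# 评论列表 (comment list) — leftover web-page navigation text, not code
# 文章目录 (table of contents) — leftover web-page navigation text, not code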