def EditTrainers(data_dir, model_dir, rep_dir, numsplits):
  """Point the DBN trainer, classifier and model configs at the given dirs.

  Each config is read, modified and written back in protobuf text format
  with ``text_format.PrintMessage``. Every path used here is relative, so it
  resolves against the current working directory.

  Args:
    data_dir: Value for ``data_proto_prefix`` in the DBN trainers whose file
      name contains ``'layer1'``.
    model_dir: Value for ``checkpoint_directory`` in every DBN trainer, for
      ``checkpoint_prefix`` in every split classifier, and for ``prefix`` in
      ``models/multimodal_dbn.pbtxt``.
    rep_dir: Value for ``data_proto_prefix`` in the remaining DBN trainers
      and in every split classifier.
    numsplits: Number of classifier configs to write, named
      ``trainers/classifiers/split_1.pbtxt`` through
      ``trainers/classifiers/split_<numsplits>.pbtxt``.

  Returns:
    None. All work is done through file side effects.

  NOTE: files are opened with mode 'w', so each target is truncated before
  it is rewritten.
  """
  def _dump(message, path):
    # Serialise `message` in protobuf text format, replacing `path`.
    with open(path, 'w') as out:
      text_format.PrintMessage(message, out)

  # 1. DBN trainers. Layer-1 trainers read from data_dir and all others read
  #    from rep_dir (presumably the representations produced by the layer
  #    below -- confirm against the training pipeline).
  dbn_dir = os.path.join('trainers', 'dbn')
  dbn_trainers = (
      'train_CD_image_layer1.pbtxt',
      'train_CD_image_layer2.pbtxt',
      'train_CD_text_layer1.pbtxt',
      'train_CD_text_layer2.pbtxt',
      'train_CD_joint_layer.pbtxt',
  )
  for trainer_name in dbn_trainers:
    trainer_path = os.path.join(dbn_dir, trainer_name)
    trainer = util.ReadOperation(trainer_path)
    is_first_layer = 'layer1' in trainer_name
    trainer.data_proto_prefix = data_dir if is_first_layer else rep_dir
    trainer.checkpoint_directory = model_dir
    _dump(trainer, trainer_path)

  # 2. Per-split classifiers, all derived from a single base config. The same
  #    message object is reused for every split. Each field this function
  #    touches is reassigned on every pass, so values from one split never
  #    leak into the next.
  classifier_dir = os.path.join('trainers', 'classifiers')
  base_path = os.path.join(classifier_dir, 'baseclassifier.pbtxt')
  base = util.ReadOperation(base_path)
  for split in range(1, numsplits + 1):
    split_name = 'split_%d' % split
    base.data_proto_prefix = rep_dir
    base.data_proto = os.path.join(split_name, 'data.pbtxt')
    base.checkpoint_prefix = model_dir
    base.checkpoint_directory = os.path.join('classifiers', split_name)
    _dump(base, os.path.join(classifier_dir, 'split_%d.pbtxt' % split))

  # 3. Model definition: only its prefix is rewritten.
  for model_name in ('multimodal_dbn.pbtxt',):
    model_path = os.path.join('models', model_name)
    model = util.ReadModel(model_path)
    model.prefix = model_dir
    _dump(model, model_path)
评论列表
文章目录