setup_data.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:Twitter100k 作者: huyt16 项目源码 文件源码
def EditTrainers(data_dir, model_dir, rep_dir, numsplits):
  """Rewrite the trainer and model .pbtxt files so they use the given directories.

  All files are addressed by paths relative to the current working directory
  ('trainers/dbn', 'trainers/classifiers', 'models') and are overwritten in
  place.

  Args:
    data_dir: value stored as data_proto_prefix for the layer-1 DBN trainers.
    model_dir: value stored as checkpoint_directory for the DBN trainers, as
      checkpoint_prefix for the classifier trainers, and as prefix for the
      multimodal DBN model.
    rep_dir: value stored as data_proto_prefix for the non-layer-1 DBN
      trainers and for every classifier trainer.
    numsplits: number of classifier trainer files to generate; one
      'split_<i>.pbtxt' is written for each i in 1..numsplits.
  """

  def write_back(message, path):
    # Serialize the protobuf message in text format, replacing the file.
    with open(path, 'w') as f:
      text_format.PrintMessage(message, f)

  # DBN trainers: layer-1 trainers use data_dir, the others use rep_dir.
  dbn_dir = os.path.join('trainers', 'dbn')
  dbn_trainers = (
      'train_CD_image_layer1.pbtxt',
      'train_CD_image_layer2.pbtxt',
      'train_CD_text_layer1.pbtxt',
      'train_CD_text_layer2.pbtxt',
      'train_CD_joint_layer.pbtxt',
  )
  for trainer_name in dbn_trainers:
    trainer_path = os.path.join(dbn_dir, trainer_name)
    trainer_op = util.ReadOperation(trainer_path)
    trainer_op.data_proto_prefix = (
        data_dir if 'layer1' in trainer_name else rep_dir)
    trainer_op.checkpoint_directory = model_dir
    write_back(trainer_op, trainer_path)

  # Classifier trainers: the base classifier operation is read once and then
  # re-pointed and written out once per split.
  classifier_dir = os.path.join('trainers', 'classifiers')
  base_op = util.ReadOperation(
      os.path.join(classifier_dir, 'baseclassifier.pbtxt'))
  for split in range(1, numsplits + 1):
    split_name = 'split_%d' % split
    base_op.data_proto_prefix = rep_dir
    base_op.data_proto = os.path.join(split_name, 'data.pbtxt')
    base_op.checkpoint_prefix = model_dir
    base_op.checkpoint_directory = os.path.join('classifiers', split_name)
    write_back(base_op, os.path.join(classifier_dir, split_name + '.pbtxt'))

  # Change prefix in multimodal dbn model
  for model_name in ('multimodal_dbn.pbtxt',):
    model_path = os.path.join('models', model_name)
    dbn_model = util.ReadModel(model_path)
    dbn_model.prefix = model_dir
    write_back(dbn_model, model_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号