model_me.py source code

python

Project: crnn_tf    Author: liuhu-bigeye
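from __future__ import print_function

import pickle

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API


# Method of the model class (the enclosing class is not shown in this snippet).
# Must be called with a default tf.Session active, because .run() uses it.
def assign_from_pkl(self, pkl_path):
    with open(pkl_path, 'rb') as f:
        # A pickle written under Python 2 may need encoding='latin1' on Python 3.
        load_variables = pickle.load(f)

    uninitialized_vars = []
    for i, variable in enumerate(tf.global_variables()):
        # Map the i-th TF variable to its index in the pickled list:
        #   variables 0-41   -> pickle entries 0-41   (offset 0)
        #   variables 42-77  -> pickle entries 52-87  (offset +10)
        #   variables 78-117 -> pickle entries 98-137 (offset +20)
        # Later variables have no pickled value and are initialized instead.
        if i <= 41:
            idx = i
        elif i <= 77:
            idx = i + 10
        elif i <= 117:
            idx = i + 20
        else:
            uninitialized_vars.append(variable)
            continue

        source = load_variables[idx]
        if source.ndim == 1:
            # 1-D arrays (e.g. biases) are copied unchanged.
            load_variable = source
        elif source.ndim == 4:
            # Reverse the axis order of 4-D arrays (e.g. conv kernels).
            load_variable = np.transpose(source, [3, 2, 1, 0])
        elif source.ndim == 3:
            # Reverse the axis order of 3-D arrays.
            load_variable = np.transpose(source, [2, 1, 0])
        else:
            raise ValueError('unsupported rank %d for %s'
                             % (source.ndim, variable.name))

        # Print both shapes so mismatches are easy to spot.
        print(variable.name, variable.get_shape(), load_variable.shape)
        variable.assign(load_variable).op.run()

    # Initialize the variables that have no pickled counterpart.
    tf.variables_initializer(uninitialized_vars).run()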
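The index offsets and axis reordering in assign_from_pkl do not depend on TensorFlow, so they can be checked on their own. Below is a minimal NumPy-only sketch, assuming the pickle holds a flat list of NumPy arrays. The helper names source_index and convert_weight are hypothetical (they are not part of crnn_tf); they simply mirror the branches of the method above. The sketch also lists which pickle entries the offsets never read.

```python
import numpy as np


def source_index(i):
    """Hypothetical helper: pickle index for the i-th TF global variable.

    Mirrors the offset table in assign_from_pkl; returns None for
    variables that have no pickled value.
    """
    if i <= 41:
        return i
    if i <= 77:
        return i + 10
    if i <= 117:
        return i + 20
    return None


def convert_weight(array):
    """Hypothetical helper: reorder axes the way assign_from_pkl does."""
    if array.ndim == 1:
        return array  # 1-D arrays are copied unchanged
    if array.ndim in (3, 4):
        # [2, 1, 0] for 3-D, [3, 2, 1, 0] for 4-D: reverse every axis
        return np.transpose(array, list(range(array.ndim))[::-1])
    raise ValueError("unsupported rank: %d" % array.ndim)


# Pickle indices 0..137 (the largest index read) that no variable uses.
used = {source_index(i) for i in range(118)}
skipped = sorted(set(range(138)) - used)
print(skipped)  # -> 42..51 and 88..97

kernel = np.zeros((64, 3, 5, 7))
print(convert_weight(kernel).shape)  # -> (7, 5, 3, 64)
```

Reversing every axis is also what np.transpose does when no axes argument is given (equivalently array.T), so either form would work; the explicit permutation just keeps the sketch close to the original [3, 2, 1, 0] and [2, 1, 0] lists.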