save_variables.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目: seq2seq_parser ｜ 作者: trangham283 ｜ 项目源码 ｜ 文件源码
def save_vars(filename):
  """Restore the model and pickle the values of its variables to `filename`.

  The model is built under the "model" variable scope with
  `create_model_default`. It is forward-only, has no dropout, and loads its
  parameters from the module-level `model_path`. Every variable returned by
  `tf.all_variables()` is then printed. All of them except those whose name
  contains 'Adagrad' are evaluated. Those entries are presumably the Adagrad
  optimizer's accumulator slots (TODO confirm). The resulting
  {variable name: evaluated value} dict is pickled to `filename`.

  Args:
    filename: path of the pickle file to write; created or overwritten.

  Side effects:
    Prints the name and shape of every variable (including skipped ones)
    to stdout.
  """
  session_config = tf.ConfigProto(intra_op_parallelism_threads=NUM_THREADS)
  with tf.Session(config=session_config) as sess:
    # Create model and load parameters. The returned model and step count are
    # not used here; the call is made for its graph-building / parameter-
    # loading side effect.
    with tf.variable_scope("model", reuse=None):
      create_model_default(sess, forward_only=True, dropout=False,
                           model_path=model_path)

    var_dict = {}
    # NOTE(review): tf.all_variables() is the pre-1.0 name; on TF >= 0.12 the
    # equivalent is tf.global_variables(). Left as-is for the TF version this
    # project targets.
    for var in tf.all_variables():
      print(var.name, var.get_shape())
      if 'Adagrad' in var.name:
        continue
      # var.eval() runs in `sess`, the default session inside this `with`.
      var_dict[var.name] = var.eval()

    # pickle writes bytes, so the file must be opened in binary mode ('wb');
    # the `with` block guarantees the data is flushed and the handle closed.
    with open(filename, 'wb') as out_file:
      pickle.dump(var_dict, out_file)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号