npy2ckpt.py source code


Project: tensorflow-deeplab-resnet  Author: DrSleep
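import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.ConfigProto)

from deeplab_resnet import DeepLabResNetModel

# get_arguments() and save() are helpers defined elsewhere in npy2ckpt.py.


def main():
    """Build the model, load the .npy weights and save them as a .ckpt checkpoint."""
    args = get_arguments()

    # Dummy input batch: only needed to build the graph, its values are never used.
    image_batch = tf.constant(0, tf.float32, shape=[1, 321, 321, 3])
    # Create the network; this also creates all of its variables.
    net = DeepLabResNetModel({'data': image_batch})
    var_list = tf.global_variables()

    # Set up the TF session; allocate GPU memory on demand.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        # Initialize every variable before overwriting them with the loaded weights.
        init = tf.global_variables_initializer()
        sess.run(init)

        # Load the .npy weights into the graph variables.
        net.load(args.npy_path, sess)

        # Saver that writes the loaded weights out as a .ckpt (V1 checkpoint format).
        saver = tf.train.Saver(var_list=var_list, write_version=1)
        save(saver, sess, args.save_dir)


if __name__ == '__main__':
    main()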
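The call `net.load(args.npy_path, sess)` is where the .npy weights reach the graph. Assuming the loader follows the caffe-tensorflow convention (the loader itself is not shown here), the .npy file holds a dict `{layer_name: {param_name: array}}` and each entry is assigned to the variable named `layer_name/param_name`. The sketch below is a hypothetical, framework-free illustration of that name matching only: `flatten_weights`, `match_variables` and all layer names are made up for illustration, and plain lists stand in for NumPy arrays.

```python
# Hypothetical sketch (not the repository's code): map a nested weight dict
# {layer_name: {param_name: value}} onto TensorFlow-style variable names
# "layer_name/param_name" (i.e. Variable.op.name, without the ":0" suffix).


def flatten_weights(weight_dict):
    """Return {"layer/param": value} for every entry of the nested dict."""
    flat = {}
    for layer_name, params in weight_dict.items():
        for param_name, value in params.items():
            flat[f"{layer_name}/{param_name}"] = value
    return flat


def match_variables(weight_dict, graph_var_names, ignore_missing=False):
    """Pair each weight with a graph variable of the same name.

    A weight with no matching variable raises KeyError unless
    ignore_missing is True, in which case it is skipped.
    Returns (assignments, unset): unset lists graph variables that got
    no weight and would therefore keep their initializer values.
    """
    known = set(graph_var_names)
    assignments = {}
    for name, value in flatten_weights(weight_dict).items():
        if name in known:
            assignments[name] = value
        elif not ignore_missing:
            raise KeyError(f"no variable named {name!r} in the graph")
    unset = sorted(known - set(assignments))
    return assignments, unset


# Usage with made-up layer names:
weights = {
    "conv1": {"weights": [[0.1, 0.2]], "biases": [0.0]},
    "bn_conv1": {"mean": [0.5], "variance": [1.0]},
}
graph_vars = ["conv1/weights", "conv1/biases", "bn_conv1/mean",
              "bn_conv1/variance", "fc_new/weights"]
assigned, unset = match_variables(weights, graph_vars)
print(sorted(assigned))  # ['bn_conv1/mean', 'bn_conv1/variance', 'conv1/biases', 'conv1/weights']
print(unset)             # ['fc_new/weights']
```

This is also why `sess.run(init)` runs before the loading step: any variable left in `unset` still holds a valid initial value, and because the Saver is built from `tf.global_variables()`, it is written to the .ckpt along with the loaded weights.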