def main():
    """Convert DeepLab-ResNet weights from a .npy file into a .ckpt checkpoint.

    Builds the network graph on a constant stand-in image, initializes the
    variables in a fresh session, loads the weights found at
    ``args.npy_path`` into it and writes them out through a
    ``tf.train.Saver`` to ``args.save_dir``.
    """
    args = get_arguments()

    # Session options: enable allow_growth on the GPU options before the
    # session is opened.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    # Constant all-zero tensor standing in for an input image.
    dummy_image = tf.constant(0, tf.float32, shape=[1, 321, 321, 3])
    model = DeepLabResNetModel({'data': dummy_image})

    # Collected right after the network is built; this exact list is what
    # the saver is restricted to below.
    model_variables = tf.global_variables()

    with tf.Session(config=session_config) as session:
        # Variables must be initialized before the stored weights are loaded.
        session.run(tf.global_variables_initializer())

        # Load the .npy weights into the live session.
        model.load(args.npy_path, session)

        # Saver that turns the loaded weights into a .ckpt file.
        # NOTE(review): write_version=1 presumably selects the V1 checkpoint
        # format -- confirm against the TensorFlow version in use.
        checkpoint_saver = tf.train.Saver(var_list=model_variables, write_version=1)
        save(checkpoint_saver, session, args.save_dir)
评论列表
文章目录