def main(_):
    """Re-save the newest checkpoint in FLAGS.input_dir into FLAGS.output_dir.

    Builds the inference graph for the configured dataset/model so that all
    model variables exist, restores them from the latest checkpoint found in
    ``FLAGS.input_dir``, and writes them back out with a
    ``tf.train.Saver(write_version=2)`` under ``FLAGS.output_dir/model``,
    suffixed with the global step.

    Args:
        _: Unused positional argument (the argv list passed by the app runner).

    Raises:
        ValueError: If ``FLAGS.model`` is unset, ``FLAGS.model_type`` or
            ``FLAGS.dataset`` has an unsupported value, or no checkpoint is
            found in ``FLAGS.input_dir``.
    """
    import os  # Local import: only needed here to build the output path.

    supported_model_types = ('vanilla', 'act', 'act_early_stopping', 'sact')
    # dataset -> (height, width, num_classes).
    # NOTE(review): 1001 ImageNet classes presumably includes a background
    # class, as in tf-slim ImageNet models — confirm against the model code.
    dataset_config = {
        'imagenet': (224, 224, 1001),
        'cifar': (32, 32, 10),
    }

    # Validate flags with explicit exceptions (asserts are stripped under -O)
    # and before touching the filesystem, so a bad invocation leaves no
    # stray output directory behind.
    if FLAGS.model is None:
        raise ValueError('FLAGS.model must be set')
    if FLAGS.model_type not in supported_model_types:
        raise ValueError('Unsupported model_type %r; expected one of %s'
                         % (FLAGS.model_type, supported_model_types))
    if FLAGS.dataset not in dataset_config:
        raise ValueError('Unsupported dataset %r; expected one of %s'
                         % (FLAGS.dataset, tuple(sorted(dataset_config))))

    if not tf.gfile.Exists(FLAGS.output_dir):
        tf.gfile.MakeDirs(FLAGS.output_dir)

    height, width, num_classes = dataset_config[FLAGS.dataset]
    batch_size = 1
    # Placeholder input: the session below only restores and saves variables
    # and never evaluates the network, so the pixel values are irrelevant.
    images = tf.random_uniform((batch_size, height, width, 3))
    # NOTE(review): presumably parses the model spec string into a list of
    # ints (e.g. units per block) — verify against utils.split_and_int.
    model = utils.split_and_int(FLAGS.model)

    # Build the network only for its side effect of creating the variables
    # that will be restored; the returned logits/end_points are not needed.
    if FLAGS.dataset == 'imagenet':
        with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
            imagenet_model.get_network(
                images,
                model,
                num_classes,
                model_type=FLAGS.model_type)
    else:  # 'cifar' (the only other dataset accepted by the check above)
        with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=False)):
            cifar_model.resnet(
                images,
                model=model,
                num_classes=num_classes,
                model_type=FLAGS.model_type)

    # Created before the Saver so that it is among the variables the Saver
    # handles. Its value is also used as the suffix of the saved checkpoint
    # name.
    tf_global_step = slim.get_or_create_global_step()

    checkpoint_path = tf.train.latest_checkpoint(FLAGS.input_dir)
    if checkpoint_path is None:
        raise ValueError('No checkpoint found in %r' % FLAGS.input_dir)

    saver = tf.train.Saver(write_version=2)
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        saver.save(sess, os.path.join(FLAGS.output_dir, 'model'),
                   global_step=tf_global_step)
# Comment list
# Table of contents