squeeze_model.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:sact 作者: mfigurnov 项目源码 文件源码
def main(_):
  """Re-save the latest checkpoint from FLAGS.input_dir into FLAGS.output_dir.

  Builds the requested model on a dummy random input batch so that its
  variables exist in the graph. Restores them from the most recent checkpoint
  in FLAGS.input_dir and writes them back out with a tf.train.Saver created
  with write_version=2. The output prefix is FLAGS.output_dir + '/model',
  suffixed with the global step by saver.save.

  Raises:
    ValueError: if --model is missing, if --model_type or --dataset has an
      unsupported value, or if no checkpoint is found in FLAGS.input_dir.
  """
  # Validate flags with explicit errors rather than `assert`: asserts are
  # stripped under `python -O`. An unknown dataset would then surface later
  # as an UnboundLocalError on `height`.
  if FLAGS.model is None:
    raise ValueError('--model must be specified.')
  model_types = ('vanilla', 'act', 'act_early_stopping', 'sact')
  if FLAGS.model_type not in model_types:
    raise ValueError('Unsupported --model_type %r; expected one of %s.' %
                     (FLAGS.model_type, ', '.join(model_types)))

  # dataset -> (height, width, num_classes) of the network input/output.
  input_specs = {
      'imagenet': (224, 224, 1001),
      'cifar': (32, 32, 10),
  }
  if FLAGS.dataset not in input_specs:
    raise ValueError('Unsupported --dataset %r; expected one of %s.' %
                     (FLAGS.dataset, ', '.join(sorted(input_specs))))
  height, width, num_classes = input_specs[FLAGS.dataset]

  if not tf.gfile.Exists(FLAGS.output_dir):
    tf.gfile.MakeDirs(FLAGS.output_dir)

  # The input values never reach a session run. The graph is only built so
  # that the model's variables are created and can be restored and re-saved.
  batch_size = 1
  images = tf.random_uniform((batch_size, height, width, 3))
  model = utils.split_and_int(FLAGS.model)

  # Define the model in inference mode (is_training=False).
  if FLAGS.dataset == 'imagenet':
    with slim.arg_scope(imagenet_model.resnet_arg_scope(is_training=False)):
      logits, end_points = imagenet_model.get_network(
          images,
          model,
          num_classes,
          model_type=FLAGS.model_type)
  else:  # 'cifar' -- the only other value accepted by the check above.
    with slim.arg_scope(cifar_model.resnet_arg_scope(is_training=False)):
      logits, end_points = cifar_model.resnet(
          images,
          model=model,
          num_classes=num_classes,
          model_type=FLAGS.model_type)

  # Ensure a global step variable exists: it is restored along with the
  # model, and it provides the step suffix of the saved checkpoint name.
  tf_global_step = slim.get_or_create_global_step()

  checkpoint_path = tf.train.latest_checkpoint(FLAGS.input_dir)
  if checkpoint_path is None:
    raise ValueError('No checkpoint found in %s.' % FLAGS.input_dir)

  # write_version=2 presumably selects the V2 checkpoint format
  # (tf.train.SaverDef.V2) -- confirm against the TF version in use.
  saver = tf.train.Saver(write_version=2)

  with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)
    saver.save(sess, FLAGS.output_dir + '/model', global_step=tf_global_step)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号