cifar_model_test.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:sact 作者: mfigurnov 项目源码 文件源码
def _runBatch(self, is_training, model_type, model=None):
    """Build a small CIFAR ResNet graph and run one random batch through it.

    In training mode a softmax cross-entropy loss (plus ponder costs for the
    ACT-style model types) is built and a single momentum-optimizer train
    step is executed. In evaluation mode the logits are computed and their
    shape is asserted to be ``(batch_size, num_classes)``.

    Args:
      is_training: Forwarded to `cifar_model.resnet_arg_scope`; also selects
        between the train-step path and the evaluation path here.
      model_type: Forwarded to `cifar_model.resnet`. For 'act',
        'act_early_stopping' and 'sact', ACT and FLOPs metric maps are
        collected, and ponder costs are added when training.
      model: Forwarded to `cifar_model.resnet` as `model`. Defaults to `[2]`
        when omitted (presumably the per-block unit counts -- confirm against
        `cifar_model.resnet`).
    """
    # Build the default per call instead of sharing one mutable list object
    # between every invocation of this helper.
    if model is None:
      model = [2]

    batch_size = 2
    height, width = 32, 32
    num_classes = 10
    uses_act = model_type in ('act', 'act_early_stopping', 'sact')

    with slim.arg_scope(
        cifar_model.resnet_arg_scope(is_training=is_training)):
      with self.test_session() as sess:
        images = tf.random_uniform((batch_size, height, width, 3))
        logits, end_points = cifar_model.resnet(
            images,
            model=model,
            num_classes=num_classes,
            model_type=model_type,
            base_channels=1)
        if uses_act:
          metrics = summary_utils.act_metric_map(end_points,
              not is_training)
          metrics.update(summary_utils.flops_metric_map(end_points,
              not is_training))
        else:
          metrics = {}

        if is_training:
          labels = tf.random_uniform(
              (batch_size,), maxval=num_classes, dtype=tf.int32)
          one_hot_labels = slim.one_hot_encoding(labels, num_classes)
          # The loss is registered in the losses collection and picked up by
          # get_total_loss() below, so its return value is not needed here.
          tf.losses.softmax_cross_entropy(
              onehot_labels=one_hot_labels, logits=logits)
          if uses_act:
            training_utils.add_all_ponder_costs(end_points, weights=1.0)
          total_loss = tf.losses.get_total_loss()
          optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
          train_op = slim.learning.create_train_op(total_loss, optimizer)
          sess.run(tf.global_variables_initializer())
          sess.run((train_op, metrics))
        else:
          # The evaluation path also initializes local variables
          # (presumably required by the eval-mode metrics -- confirm in
          # summary_utils).
          sess.run([tf.local_variables_initializer(),
              tf.global_variables_initializer()])
          logits_out, _ = sess.run((logits, metrics))
          self.assertEqual(logits_out.shape, (batch_size, num_classes))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号