imagenet_model_test.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:sact 作者: mfigurnov 项目源码 文件源码
def _runBatch(self,
                is_training,
                model_type,
                model=None):
    """Builds an ImageNet model on random images and runs one batch through it.

    The network comes from `imagenet_model.get_network` with `base_channels=1`
    and is fed a batch of uniformly random 128x128 RGB images. In training mode
    a label-smoothed softmax cross-entropy loss is added, plus ponder costs for
    the 'act', 'act_early_stopping' and 'sact' variants. A momentum train op is
    built and a single training step is run. In both modes the logits are
    fetched and their shape is checked.

    Args:
      is_training: Whether to build the model (via `resnet_arg_scope`) and the
        run in training mode.
      model_type: Network variant forwarded to `imagenet_model.get_network`.
        For 'act', 'act_early_stopping' and 'sact' the metric maps from
        `summary_utils.act_metric_map` / `flops_metric_map` are also built and
        fetched.
      model: Model specification forwarded to `get_network`. Defaults to
        [2, 2, 2, 2] (presumably the number of units per block group —
        confirm against `imagenet_model`).
    """
    if model is None:
      # A fresh list per call avoids sharing one mutable default object.
      model = [2, 2, 2, 2]
    batch_size = 2
    height, width = 128, 128
    num_classes = 10
    # Only these variants get ACT/FLOPs metrics and ponder costs.
    is_adaptive = model_type in ('act', 'act_early_stopping', 'sact')

    with self.test_session() as sess:
      images = tf.random_uniform((batch_size, height, width, 3))
      with slim.arg_scope(
          imagenet_model.resnet_arg_scope(is_training=is_training)):
        # Forward model_type so each variant builds its own network instead
        # of always building an SACT network.
        logits, end_points = imagenet_model.get_network(
            images, model, num_classes, model_type=model_type,
            base_channels=1)
        metrics = {}
        if is_adaptive:
          metrics.update(
              summary_utils.act_metric_map(end_points, not is_training))
          metrics.update(
              summary_utils.flops_metric_map(end_points, not is_training))

      fetches = {'logits': logits, 'metrics': metrics}
      if is_training:
        labels = tf.random_uniform(
            (batch_size,), maxval=num_classes, dtype=tf.int32)
        one_hot_labels = slim.one_hot_encoding(labels, num_classes)
        tf.losses.softmax_cross_entropy(
            onehot_labels=one_hot_labels, logits=logits,
            label_smoothing=0.1, weights=1.0)
        if is_adaptive:
          training_utils.add_all_ponder_costs(end_points, weights=1.0)
        total_loss = tf.losses.get_total_loss()
        optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
        fetches['train_op'] = slim.learning.create_train_op(
            total_loss, optimizer)

      # Initialize only after the whole graph (including optimizer slots and
      # any metric variables) exists, identically for both modes.
      sess.run([tf.local_variables_initializer(),
                tf.global_variables_initializer()])
      results = sess.run(fetches)
      self.assertEqual(results['logits'].shape, (batch_size, num_classes))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号