def _runBatch(self, is_training, model_type, model=None):
  """Builds a small CIFAR resnet and runs one training or eval batch.

  Args:
    is_training: bool. If True, builds the model in training mode, attaches
      a softmax cross-entropy loss (plus ponder costs for ACT variants) and
      runs a single momentum-SGD train step. If False, runs a forward pass
      and checks the logits shape.
    model_type: str model variant passed to cifar_model.resnet; the values
      'act', 'act_early_stopping' and 'sact' additionally get ACT/flops
      metric maps and ponder costs.
    model: optional block-structure argument forwarded to cifar_model.resnet.
      Defaults to [2]. (Uses the None-sentinel idiom instead of a mutable
      default argument, which would be shared across calls.)
  """
  if model is None:
    model = [2]
  batch_size = 2
  height, width = 32, 32
  num_classes = 10
  with slim.arg_scope(
      cifar_model.resnet_arg_scope(is_training=is_training)):
    with self.test_session() as sess:
      images = tf.random_uniform((batch_size, height, width, 3))
      logits, end_points = cifar_model.resnet(
          images,
          model=model,
          num_classes=num_classes,
          model_type=model_type,
          base_channels=1)

      if model_type in ('act', 'act_early_stopping', 'sact'):
        # ACT variants expose extra tensors; collect their metric maps so
        # the sess.run below exercises them end to end.
        metrics = summary_utils.act_metric_map(end_points, not is_training)
        metrics.update(
            summary_utils.flops_metric_map(end_points, not is_training))
      else:
        metrics = {}

      if is_training:
        labels = tf.random_uniform(
            (batch_size,), maxval=num_classes, dtype=tf.int32)
        one_hot_labels = slim.one_hot_encoding(labels, num_classes)
        tf.losses.softmax_cross_entropy(
            onehot_labels=one_hot_labels, logits=logits)
        if model_type in ('act', 'act_early_stopping', 'sact'):
          training_utils.add_all_ponder_costs(end_points, weights=1.0)
        total_loss = tf.losses.get_total_loss()
        optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
        train_op = slim.learning.create_train_op(total_loss, optimizer)
        sess.run(tf.global_variables_initializer())
        sess.run((train_op, metrics))
      else:
        # Eval path: local variables are needed for streaming metrics.
        sess.run([tf.local_variables_initializer(),
                  tf.global_variables_initializer()])
        logits_out, metrics_out = sess.run((logits, metrics))
        self.assertEqual(logits_out.shape, (batch_size, num_classes))