def _runBatch(self,
              is_training,
              model_type,
              model=None):
    """Build a small ImageNet-style network and run one batch through it.

    Feeds a batch of random images through ``imagenet_model.get_network``.
    In training mode a label-smoothed softmax loss (plus ponder costs for the
    adaptive model types) is attached and a single momentum-optimizer step is
    run. Otherwise a forward pass is run and the logits shape is checked.

    Args:
      is_training: True to build and run one training step, False to run an
        inference pass and check the logits shape.
      model_type: Network variant forwarded to ``get_network``. For 'act',
        'act_early_stopping' and 'sact' the ACT and FLOPs metric maps are
        built, and in training mode the ponder costs are added to the loss.
      model: Configuration passed positionally to ``get_network``. ``None``
        (the default) selects ``[2, 2, 2, 2]``. This is presumably the number
        of units per block group — confirm against imagenet_model. A fresh
        list is built on each call rather than sharing a mutable default.
    """
    if model is None:
        model = [2, 2, 2, 2]
    batch_size = 2
    height, width = 128, 128
    num_classes = 10
    # Model types whose end points carry ACT information (metrics, ponder cost).
    is_adaptive = model_type in ('act', 'act_early_stopping', 'sact')
    with self.test_session() as sess:
        images = tf.random_uniform((batch_size, height, width, 3))
        with slim.arg_scope(
                imagenet_model.resnet_arg_scope(is_training=is_training)):
            # Forward the requested variant so each model type is actually
            # exercised (it used to be hard-coded to 'sact').
            logits, end_points = imagenet_model.get_network(
                images, model, num_classes, model_type=model_type,
                base_channels=1)
        if is_adaptive:
            metrics = summary_utils.act_metric_map(end_points,
                                                   not is_training)
            metrics.update(summary_utils.flops_metric_map(end_points,
                                                          not is_training))
        else:
            metrics = {}
        if is_training:
            labels = tf.random_uniform(
                (batch_size,), maxval=num_classes, dtype=tf.int32)
            one_hot_labels = slim.one_hot_encoding(labels, num_classes)
            tf.losses.softmax_cross_entropy(
                onehot_labels=one_hot_labels, logits=logits,
                label_smoothing=0.1, weights=1.0)
            if is_adaptive:
                training_utils.add_all_ponder_costs(end_points, weights=1.0)
            total_loss = tf.losses.get_total_loss()
            optimizer = tf.train.MomentumOptimizer(0.1, 0.9)
            train_op = slim.learning.create_train_op(total_loss, optimizer)
            # The initializers are created only after create_train_op, so the
            # optimizer's variables are included. Local variables are
            # initialized too, matching the inference branch, in case any
            # metric op relies on them.
            sess.run([tf.local_variables_initializer(),
                      tf.global_variables_initializer()])
            sess.run((train_op, metrics))
        else:
            sess.run([tf.local_variables_initializer(),
                      tf.global_variables_initializer()])
            # Metrics are evaluated to make sure they run; only the logits
            # shape is checked.
            logits_out, _ = sess.run((logits, metrics))
            self.assertEqual(logits_out.shape, (batch_size, num_classes))
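# NOTE(review): two lines of web-page navigation residue stood here:
# "评论列表" ("Comment list") and "文章目录" ("Table of contents").
# They are not code and are kept only as this comment so the file parses.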