def __init__(self, sess, checkpoint_dir, log_dir, training_paths, testing_paths, roi, im_size, nclass,
             batch_size=1, layers=3, features_root=32, conv_size=3, dropout=0.5, testing_gt_available=True,
             loss_type='cross_entropy', class_weights=None):
    """Record the configuration, build the model, then create a checkpoint saver.

    Every argument is stored unchanged as an instance attribute of the same
    name. ``self.build_model()`` is called after all attributes are set, and
    ``self.saver`` is created last.

    Args:
        sess: Stored as ``self.sess``. Presumably a TF session; confirm
            against the caller.
        checkpoint_dir: Stored as ``self.checkpoint_dir``.
        log_dir: Stored as ``self.log_dir``.
        training_paths: Stored as ``self.training_paths``.
        testing_paths: Stored as ``self.testing_paths``.
        roi: Stored as ``self.roi``. The original author annotates it as a
            ``(roi_order, roi_name)`` pair.
        im_size: Stored as ``self.im_size``.
        nclass: Stored as ``self.nclass``.
        batch_size: Stored as ``self.batch_size`` (default 1).
        layers: Stored as ``self.layers`` (default 3).
        features_root: Stored as ``self.features_root`` (default 32).
        conv_size: Stored as ``self.conv_size`` (default 3).
        dropout: Stored as ``self.dropout`` (default 0.5).
        testing_gt_available: Stored as ``self.testing_gt_available``
            (default True).
        loss_type: Stored as ``self.loss_type`` (default 'cross_entropy').
        class_weights: Stored as ``self.class_weights`` (default None).
    """
    # Session and output locations.
    self.sess = sess
    self.checkpoint_dir, self.log_dir = checkpoint_dir, log_dir

    # Data sources.
    self.training_paths, self.testing_paths = training_paths, testing_paths
    self.testing_gt_available = testing_gt_available

    # Problem definition.
    self.nclass, self.im_size = nclass, im_size
    self.roi = roi  # annotated by the author as (roi_order, roi_name)

    # Network and training hyper-parameters.
    self.batch_size = batch_size
    self.layers, self.features_root, self.conv_size = layers, features_root, conv_size
    self.dropout = dropout
    self.loss_type, self.class_weights = loss_type, class_weights

    # build_model() runs only once every attribute above is set. Presumably it
    # reads them; confirm against its definition.
    self.build_model()

    # The saver is created after build_model(). Presumably that call creates
    # the variables gathered here. 'bn_collections' looks like a batch-norm
    # collection, but that is inferred from the name only (TODO confirm).
    saved_vars = tf.trainable_variables()
    saved_vars = saved_vars + tf.get_collection_ref('bn_collections')
    self.saver = tf.train.Saver(saved_vars)
# [Web-page residue from the source article, not code: "Comment list" / "Table of contents" headings.]