def __init__(self, sess, checkpoint_dir, log_dir, training_paths, testing_paths,
             batch_size=1, layers=3, features_root=32, conv_size=3, dropout=0.5,
             loss_type='cross_entropy', class_weights=None, nclass=4, patch_stride=4):
    """Configure the model from the training-data layout, build it, and create a saver.

    Patch geometry is inferred from the data on disk: the patch stored at
    ``os.path.join(training_paths[0], '0')`` is loaded with ``read_patch``;
    all axes of its shape except the last become ``self.patch_size`` and the
    last axis becomes ``self.channel``. The number of entries in the first
    training directory becomes ``self.patches_per_image``. After all
    attributes are set, ``self.build_model()`` is called and a saver over the
    trainable variables plus the ``'bn_collections'`` collection is created.

    Args:
        sess: Stored as ``self.sess`` (presumably a ``tf.Session`` -- TODO
            confirm against callers).
        checkpoint_dir: Stored as ``self.checkpoint_dir``.
        log_dir: Stored as ``self.log_dir``.
        training_paths: Indexable sequence of training directories. Must be
            non-empty; only the first entry is inspected here.
        testing_paths: Stored as ``self.testing_paths``.
        batch_size: Stored as ``self.batch_size``.
        layers: Stored as ``self.layers``.
        features_root: Stored as ``self.features_root``.
        conv_size: Stored as ``self.conv_size``.
        dropout: Stored as ``self.dropout``.
        loss_type: Stored as ``self.loss_type``.
        class_weights: Stored as ``self.class_weights``.
        nclass: Stored as ``self.nclass``; defaults to 4, the value that was
            previously hard-coded (presumably the number of output classes
            consumed by ``build_model`` -- confirm).
        patch_stride: Stored as ``self.patch_stride`` (used in deploy);
            defaults to 4, the value that was previously hard-coded.

    Raises:
        ValueError: If ``training_paths`` is empty.
    """
    # len() rather than truthiness so an array-like sequence of paths also works.
    if len(training_paths) == 0:
        # Fail with a clear message instead of an IndexError on training_paths[0].
        raise ValueError('training_paths must contain at least one directory')

    self.sess = sess
    self.checkpoint_dir = checkpoint_dir
    self.log_dir = log_dir
    self.training_paths = training_paths
    self.testing_paths = testing_paths

    # Infer patch geometry from the first stored patch; the last axis of its
    # shape is treated as the channel axis.
    first_training_dir = self.training_paths[0]
    image, _ = read_patch(os.path.join(first_training_dir, '0'))
    self.nclass = nclass
    self.batch_size = batch_size
    self.patch_size = image.shape[:-1]
    self.patch_stride = patch_stride  # Used in deploy
    self.channel = image.shape[-1]

    self.layers = layers
    self.features_root = features_root
    self.conv_size = conv_size
    self.dropout = dropout
    self.loss_type = loss_type
    self.class_weights = class_weights

    # Counts every entry in the first training directory.
    # NOTE(review): assumes all entries there are patches and that every
    # training directory holds the same number of them -- confirm.
    self.patches_per_image = len(os.listdir(first_training_dir))

    # The graph must exist before the saver can collect its variables.
    self.build_model()
    # Save trainable variables plus everything registered in 'bn_collections'
    # (presumably non-trainable batch-norm statistics -- confirm in build_model).
    self.saver = tf.train.Saver(tf.trainable_variables() + tf.get_collection_ref('bn_collections'))
评论列表
文章目录