model.py source code

python

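Project: brats17  Author: xf4j
    # Constructor of the model class in model.py (excerpt). The module is assumed to
    # import os and tensorflow as tf (TF 1.x API); read_patch is a project helper and
    # build_model a method of the same class, neither of which is shown here.
    def __init__(self, sess, checkpoint_dir, log_dir, training_paths, testing_paths,
                 batch_size=1, layers=3, features_root=32, conv_size=3, dropout=0.5,
                 loss_type='cross_entropy', class_weights=None):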
        self.sess = sess

        self.checkpoint_dir = checkpoint_dir
        self.log_dir = log_dir

        self.training_paths = training_paths
        self.testing_paths = testing_paths

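        # Load one sample patch (patch '0' of the first training image) to infer
        # the patch geometry and the number of input channels.
        image, _ = read_patch(os.path.join(self.training_paths[0], '0'))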

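        self.nclass = 4  # number of output classes (background plus three tumor labels)
        self.batch_size = batch_size
        self.patch_size = image.shape[:-1]  # spatial dimensions; arrays are channels-last
        self.patch_stride = 4  # sliding-window stride, used at deployment time
        self.channel = image.shape[-1]  # number of input channels (e.g. MRI modalities)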
        self.layers = layers
        self.features_root = features_root
        self.conv_size = conv_size
        self.dropout = dropout
        self.loss_type = loss_type
        self.class_weights = class_weights
        # Assumes every training image directory holds the same number of patches
        # as the first one.
        self.patches_per_image = len(os.listdir(self.training_paths[0]))

        self.build_model()

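        # Save the trainable weights plus the variables registered in 'bn_collections'
        # (presumably batch-norm moving statistics, which are not trainable).
        self.saver = tf.train.Saver(tf.trainable_variables() + tf.get_collection_ref('bn_collections'))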
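The constructor infers the patch geometry from a single sample instead of taking it as an argument. `image.shape[:-1]` gives the spatial dimensions and `image.shape[-1]` the channel count, which only works if patches are stored channels-last. The code also implies a directory layout: each training path is a directory of patch files named `'0'`, `'1'`, and so on, so `len(os.listdir(...))` is the number of patches per image. The sketch below reproduces that logic with NumPy. `fake_read_patch`, the `.npy` storage format, the `(image, label)` return pair, and the `(32, 32, 32, 4)` shape are all hypothetical stand-ins, since the project's `read_patch` is not shown here.

```python
import os
import tempfile

import numpy as np


def fake_read_patch(path):
    """Hypothetical stand-in for the project's read_patch: returns (image, label)."""
    with open(path, "rb") as f:
        image = np.load(f)  # first array stored in the file
        label = np.load(f)  # second array stored in the same file
    return image, label


def infer_patch_geometry(image_dir):
    """Mirror the constructor: read patch '0', then count the patch files."""
    image, _ = fake_read_patch(os.path.join(image_dir, "0"))
    patch_size = image.shape[:-1]  # spatial dims (channels-last layout)
    channel = image.shape[-1]      # number of input channels
    patches_per_image = len(os.listdir(image_dir))
    return patch_size, channel, patches_per_image


with tempfile.TemporaryDirectory() as image_dir:
    # Write three dummy patches named '0', '1', '2' (shapes are assumptions).
    for i in range(3):
        with open(os.path.join(image_dir, str(i)), "wb") as f:
            np.save(f, np.zeros((32, 32, 32, 4), dtype=np.float32))
            np.save(f, np.zeros((32, 32, 32), dtype=np.uint8))
    geometry = infer_patch_geometry(image_dir)

print(geometry)
```

Because the count comes from the first directory only, a directory with a different number of patches would go unnoticed; a stricter variant could check every entry of `training_paths`.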
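`patch_stride = 4` is only annotated as "used in deploy", and the deploy code is not part of this excerpt. One common way to use such a stride is to slide a patch-sized window across the full volume and combine the overlapping predictions. The sketch below computes only the window start offsets for a given `patch_size` and stride, then checks that every voxel is covered. The extra final window, which makes the last patch end exactly at the volume border, is an assumption about the deploy logic, not something the source shows.

```python
from itertools import product

import numpy as np


def window_starts(dim, patch, stride):
    """Start offsets of patch-sized windows along one axis of length dim."""
    if patch > dim:
        raise ValueError("patch is larger than the volume along this axis")
    starts = list(range(0, dim - patch + 1, stride))
    # Assumed behavior: add a final window flush with the border if needed.
    if starts[-1] != dim - patch:
        starts.append(dim - patch)
    return starts


def sliding_window_corners(volume_shape, patch_size, stride):
    """All window corners for an N-D volume (Cartesian product over axes)."""
    per_axis = [window_starts(d, p, stride) for d, p in zip(volume_shape, patch_size)]
    return list(product(*per_axis))


volume_shape, patch_size, stride = (10, 8, 8), (4, 4, 4), 4
corners = sliding_window_corners(volume_shape, patch_size, stride)

# Count how many windows cover each voxel.
coverage = np.zeros(volume_shape, dtype=int)
for corner in corners:
    window = tuple(slice(c, c + p) for c, p in zip(corner, patch_size))
    coverage[window] += 1

print(window_starts(10, 4, 4), len(corners), coverage.min())
```

With a stride smaller than the patch size, windows overlap and `coverage` gives the per-voxel divisor needed to average the summed predictions.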
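`loss_type='cross_entropy'` and `class_weights` are consumed by `build_model`, which is not shown. One plausible reading, and it is only an assumption, is a per-voxel softmax cross-entropy in which each voxel's loss is scaled by the weight of its true class. This is a standard way to counter the heavy background imbalance in brain-tumor segmentation. The NumPy sketch below implements that weighted cross-entropy for `nclass = 4`. Normalizing by the sum of the weights rather than by the voxel count is one choice among several, and the function names are hypothetical.

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax along the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def weighted_cross_entropy(logits, labels, class_weights=None):
    """Softmax cross-entropy with each voxel weighted by its true class.

    logits: (..., nclass) float array; labels: (...) int array in [0, nclass).
    With class_weights=None every class has weight 1 (plain cross-entropy).
    """
    nclass = logits.shape[-1]
    if class_weights is None:
        weights = np.ones(nclass)
    else:
        weights = np.asarray(class_weights, dtype=np.float64)
    probs = softmax(logits)
    # Probability assigned to the true class at each voxel.
    p_true = np.take_along_axis(probs, labels[..., None], axis=-1)[..., 0]
    w = weights[labels]
    return float(np.sum(-w * np.log(p_true)) / np.sum(w))


# Uniform logits: every class gets probability 1/4, so the loss is log(4)
# whatever the weights are.
logits = np.zeros((2, 2, 2, 4))
labels = np.array([[[0, 0], [0, 1]], [[2, 3], [0, 0]]])
print(weighted_cross_entropy(logits, labels))

# One confident, correct background voxel and one misclassified tumor voxel:
# down-weighting background increases the share of the mistake in the loss.
logits2 = np.array([[5.0, 0.0, 0.0, 0.0], [5.0, 0.0, 0.0, 0.0]])
labels2 = np.array([0, 1])
unweighted = weighted_cross_entropy(logits2, labels2)
weighted = weighted_cross_entropy(logits2, labels2, [0.1, 1.0, 1.0, 1.0])
print(unweighted, weighted)
```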