def stylize(args):
    """Stylize one content image with a pretrained TransformerNet and save it.

    Reads from ``args``: ``content_image`` (input path), ``content_scale``
    (passed to ``utils.load_image`` as ``scale``), ``model`` (state-dict path),
    ``output_image`` (destination path) and ``cuda`` (run on GPU when truthy).
    """
    # ToTensor then scale by 255, so the network sees pixel values in [0, 255]
    # (assuming ToTensor yields [0, 1] floats for the loaded image).
    to_tensor_255 = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255)),
    ])
    raw = utils.load_image(args.content_image, scale=args.content_scale)
    batch = to_tensor_255(raw).unsqueeze(0)  # add a leading batch dimension
    if args.cuda:
        batch = batch.cuda()
    # volatile=True is the legacy (pre-0.4) PyTorch way to request inference
    # without recording a graph; newer releases ignore it — verify torch version.
    batch = Variable(batch, volatile=True)

    net = TransformerNet()
    net.load_state_dict(torch.load(args.model))
    if args.cuda:
        net.cuda()

    stylized = net(batch)
    if args.cuda:
        stylized = stylized.cpu()
    # Drop the batch dimension before saving.
    utils.save_image(args.output_image, stylized.data[0])
Python load_image() example source code
def eval_accuracy_loss(X_data, y_data, BATCH_SIZE, top1_accuracy, top5_accuracy,
                       loss_operation, images, y, RC, train_mode, regConst):
    """Evaluate top-1/top-5 accuracy and mean loss over a dataset in mini-batches.

    Runs in the current default TensorFlow session. ``X_data`` slices are
    turned into image batches by ``utils.load_image``; ``y_data`` holds the
    matching labels.

    Returns:
        ``(top1, top5, loss)``, each divided by ``len(X_data)``. The loss of
        each batch is weighted by the batch size before averaging, while the
        top-1/top-5 values are summed unweighted.
        NOTE(review): that is only a correct average if the accuracy ops
        return per-batch *counts* rather than means — confirm.

    NOTE(review): ``KP`` (keep-probability placeholder) is not a parameter;
    it is looked up in module scope and fed 1, i.e. no dropout at eval time.
    Raises ``ZeroDivisionError`` if ``X_data`` is empty.
    """
    n_images = len(X_data)
    sess = tf.get_default_session()
    sum_top1 = sum_top5 = sum_loss = 0.0
    for start in range(0, n_images, BATCH_SIZE):
        stop = start + BATCH_SIZE
        batch_x = utils.load_image(X_data[start:stop])
        feed = {images: batch_x, y: y_data[start:stop], RC: regConst,
                KP: 1, train_mode: False}
        top1, top5, loss = sess.run(
            [top1_accuracy, top5_accuracy, loss_operation], feed_dict=feed)
        sum_top1 += top1
        sum_top5 += top5
        sum_loss += loss * len(batch_x)
    return sum_top1 / n_images, sum_top5 / n_images, sum_loss / n_images
def main():
    """Dump image previews, per-class masks and polygon statistics.

    For every image id returned by ``utils.get_wkt_data()`` this writes
    ``<im_id>.jpg`` (percentile-scaled image) and one
    ``<im_id>_mask_<cls>.png`` per polygon class into the output directory
    given on the command line. It also writes all statistics to
    ``stats.json`` and prints one table per statistic
    (number, perimeter, area).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('output', help='output directory')
    args = parser.parse_args()
    output = Path(args.output)
    # parents=True so a nested, not-yet-existing output path also works.
    output.mkdir(parents=True, exist_ok=True)

    poly_stats = {}
    for im_id in sorted(utils.get_wkt_data()):
        print(im_id)
        im_data = utils.load_image(im_id, rgb_only=True)
        im_data = utils.scale_percentile(im_data)
        # NOTE(review): cv2.imwrite interprets 3-channel data as BGR; confirm
        # the channel order utils.load_image(rgb_only=True) returns.
        cv2.imwrite(str(output.joinpath('{}.jpg'.format(im_id))), 255 * im_data)
        im_size = im_data.shape[:2]
        n_pixels = im_size[0] * im_size[1]
        poly_by_type = utils.load_polygons(im_id, im_size)
        for poly_type, poly in sorted(poly_by_type.items()):
            cls = poly_type - 1  # polygon types are 1-based, class indices 0-based
            mask = utils.mask_for_polygons(im_size, poly)
            cv2.imwrite(
                str(output.joinpath('{}_mask_{}.png'.format(im_id, cls))),
                255 * mask)
            poly_stats.setdefault(im_id, {})[cls] = {
                'area': poly.area / n_pixels,  # fraction of the image area
                'perimeter': int(poly.length),
                # Multi-part geometries are not sequences in Shapely 2.0;
                # .geoms works on every Shapely version.
                'number': len(poly.geoms),
            }
    output.joinpath('stats.json').write_text(json.dumps(poly_stats))

    n_classes = 10
    headers = ['im_id'] + list(range(n_classes))
    for key in ['number', 'perimeter', 'area']:
        fmt = '{:.4%}'.format if key == 'area' else (lambda x: x)
        # Classes absent from an image get an empty cell instead of a KeyError.
        rows = [
            [im_id] + [fmt(s[cls][key]) if cls in s else ''
                       for cls in range(n_classes)]
            for im_id, s in sorted(poly_stats.items())
        ]
        print('\n{}'.format(key))
        print(tabulate.tabulate(rows, headers=headers))
def load_image(self, im_id: str) -> Image:
    """Load one image plus its class masks, caching both under ``im_cache/``.

    The preprocessed image array is cached as ``im_cache/<im_id>.data``. The
    mask stack is cached as ``im_cache/<im_id>.mask``. When
    ``self.hps.pre_buffer`` is set, polygon type 2 is buffered by that amount
    before rasterising, and the mask cache is neither read nor written.

    Returns:
        ``Image(im_id, data, mask)``, where ``data`` is truncated to the first
        ``self.hps.n_channels`` entries along axis 0 and ``mask`` is the stack
        indexed by ``self.hps.classes``.
    """
    logger.info('Loading {}'.format(im_id))
    cache_dir = Path('im_cache')
    cache_dir.mkdir(exist_ok=True)
    data_file = cache_dir.joinpath('{}.data'.format(im_id))
    mask_file = cache_dir.joinpath('{}.mask'.format(im_id))

    if not data_file.exists():
        im_data = self.preprocess_image(utils.load_image(im_id))
        with data_file.open('wb') as fh:
            np.save(fh, im_data)
    else:
        im_data = np.load(str(data_file))

    buffer_dist = self.hps.pre_buffer
    # A buffered mask differs from the plain one, so only the plain one is cached.
    cached_mask_usable = mask_file.exists() and not buffer_dist
    if cached_mask_usable:
        mask = np.load(str(mask_file))
    else:
        size = im_data.shape[1:]  # spatial dims; axis 0 is channels
        polygons = utils.load_polygons(im_id, size)
        if buffer_dist:
            structures = 2  # polygon type that gets buffered
            polygons[structures] = utils.to_multipolygon(
                polygons[structures].buffer(buffer_dist))
        # Polygon types are 1-based, mask rows 0-based.
        mask = np.array(
            [utils.mask_for_polygons(size, polygons[c + 1])
             for c in range(self.hps.total_classes)],
            dtype=np.uint8)
        if not buffer_dist:
            with mask_file.open('wb') as fh:
                np.save(fh, mask)

    wanted_channels = self.hps.n_channels
    if wanted_channels != im_data.shape[0]:
        im_data = im_data[:wanted_channels]
    return Image(im_id, im_data, mask[self.hps.classes])
def predict_masks(args, hps, store, to_predict: List[str], threshold: float,
                  validation: str=None, no_edges: bool=False):
    """Predict a mask for each id in ``to_predict`` and save it gzipped.

    Model weights come from ``args.model_path`` when set, otherwise from the
    latest snapshot in ``args.logdir``. Loading runs through
    ``utils.imap_fixed_output_buffer`` with 2 threads. Prediction is pulled
    lazily by a second single-thread ``imap_fixed_output_buffer``, so
    loading, predicting and saving overlap. NOTE(review): assumes that helper
    yields results in input order — its semantics are defined elsewhere.

    Args:
        args: needs ``model_path`` and ``logdir``.
        hps: model hyper-parameters; ``n_channels`` limits the channels kept.
        store: passed to ``mask_path`` to build each output file path.
        to_predict: image ids to process.
        threshold: mask values ``>= threshold`` are saved as ``True``.
        validation: if ``'square'``, each image is passed through
            ``square(data, hps)`` before prediction.
        no_edges: forwarded to ``model.predict_image_mask``.
    """
    logger.info('Predicting {} masks: {}'
                .format(len(to_predict), ', '.join(sorted(to_predict))))
    model = Model(hps=hps)
    if args.model_path:
        model.restore_snapshot(args.model_path)
    else:
        model.restore_last_snapshot(args.logdir)

    def load_im(im_id):
        # Preprocess, then keep only the first hps.n_channels entries of axis 0.
        data = model.preprocess_image(utils.load_image(im_id))
        if hps.n_channels != data.shape[0]:
            data = data[:hps.n_channels]
        if validation == 'square':
            data = square(data, hps)
        return Image(id=im_id, data=data)

    def predict_mask(im):
        logger.info(im.id)
        return im, model.predict_image_mask(im.data, no_edges=no_edges)

    # Lazy: predict_mask runs only when the loop below calls next(im_masks).
    im_masks = map(predict_mask, utils.imap_fixed_output_buffer(
        load_im, sorted(to_predict), threads=2))
    # The lambda ignores its argument; to_predict only sets how many items are
    # pulled from im_masks, which yields images in sorted-id order.
    for im, mask in utils.imap_fixed_output_buffer(
            lambda _: next(im_masks), to_predict, threads=1):
        # Sanity check: predicted mask must match the image's spatial size.
        assert mask.shape[1:] == im.data.shape[1:]
        with gzip.open(str(mask_path(store, im.id)), 'wb') as f:
            # TODO - maybe do (mask * 20).astype(np.uint8)
            np.save(f, mask >= threshold)
def load_batches(self):
    """Load ``self.max_batches`` batches of images into memory.

    On the first call (while ``self.batches`` is empty) this allocates
    ``max_batches`` zero arrays of shape ``(batch_size, image_h, image_w, 3)``
    and stores the list in ``self.batches``. Later calls refill those arrays
    in place.

    Unless ``self.image_dir`` is the ``DUMMY`` sentinel, files are read from
    it starting at index ``self.last_load``. The index moves forward, or
    backward when ``self.valid`` is set. The final index is written back to
    ``self.last_load`` so the next call resumes there.

    Files that cannot be loaded (or do not fit the batch array) are skipped.

    NOTE(review): running past the end of the file listing raises IndexError
    (not caught here). Also, ``os.listdir`` order is not guaranteed, so
    resuming by index assumes a stable listing.
    """
    if self.batches:
        batches = self.batches
        is_new = False
    else:
        is_new = True
        batches = []
        self.batches = batches
    image_dir = self.image_dir
    batch_size = self.batch_size
    image_h = self.image_h
    image_w = self.image_w
    file_list = os.listdir(image_dir) if image_dir != DUMMY else []
    n = self.last_load
    for b in range(self.max_batches):
        if is_new:
            arr = np.zeros((batch_size, image_h, image_w, 3))
            batches.append(arr)
        else:
            arr = batches[b]
        if image_dir == DUMMY:
            continue
        i = 0
        while i < batch_size:
            file_name = file_list[n]
            try:
                image = utils.load_image(
                    os.path.join(image_dir, file_name), image_h, image_w)
                arr[i] = image
            except Exception:
                # Deliberate best effort: skip unreadable or wrongly shaped
                # files. Catching Exception (not a bare except) lets
                # KeyboardInterrupt/SystemExit propagate.
                pass
            else:
                i += 1
            n += 1 if not self.valid else -1
    self.last_load = n
def fineTuneNet(X_train, y_train, BATCH_SIZE, images, y, LR, RC, train_mode,
                learnRate, regConst, mask, layList, KP, keepProb, indMask):
    """Run one shuffled pass of masked fine-tuning over the training set.

    For every mini-batch this runs the gradient ops ``applygradK``, from the
    highest masked layer down to layer 0. Each run feeds the placeholder
    ``MaskK`` with ``mask[K]``. ``applygradK`` and ``MaskK`` are module-level
    names, not parameters.

    Only ``len(mask)`` of 1, 2 or 4 is supported. Any other value prints
    "wrong number of layers passed" and applies no updates.

    Args:
        X_train, y_train: training inputs/labels, shuffled together; input
            slices are turned into image batches by ``utils.load_image``.
        BATCH_SIZE: mini-batch size.
        images, y, LR, RC, KP, train_mode: graph placeholders that are fed.
        learnRate, regConst, keepProb: values fed to ``LR``, ``RC`` and ``KP``.
        mask: per-layer mask arrays.
        layList: unused; kept for interface compatibility.
        indMask: only ``indMask[0][0]`` is used (single-layer case), as a
            count of entries to discount from ``mask[0]``.
    """
    nLayers = len(mask)
    print(["Function fineTune", nLayers])
    sess = tf.get_default_session()
    n_train = len(y_train)
    X_train, y_train = shuffle(X_train, y_train)

    # (grad op, mask placeholder, mask value, keep prob) in execution order.
    # None of these change between batches, so they are resolved once.
    if nLayers == 1:
        # Dropout is adjusted only for the single (fully connected) layer:
        # keep-prob is scaled by sqrt of the fraction of mask[0] entries not
        # listed in indMask[0][0].
        n_weights = float(np.prod(mask[0].shape))
        rat = (n_weights - len(indMask[0][0])) / n_weights
        steps = [(applygrad0, Mask0, mask[0], keepProb * np.sqrt(rat))]
    elif nLayers == 2:
        steps = [(applygrad1, Mask1, mask[1], keepProb),
                 (applygrad0, Mask0, mask[0], keepProb)]
    elif nLayers == 4:
        steps = [(applygrad3, Mask3, mask[3], keepProb),
                 (applygrad2, Mask2, mask[2], keepProb),
                 (applygrad1, Mask1, mask[1], keepProb),
                 (applygrad0, Mask0, mask[0], keepProb)]
    else:
        # Bail out before loading any batch that would only be discarded.
        print("wrong number of layers passed")
        return

    for offset in range(0, n_train, BATCH_SIZE):
        end = offset + BATCH_SIZE
        batch_x = utils.load_image(X_train[offset:end])
        batch_y = y_train[offset:end]
        for grad_op, mask_ph, mask_value, keep_prob in steps:
            sess.run(grad_op, feed_dict={
                mask_ph: mask_value, images: batch_x, y: batch_y,
                LR: learnRate, RC: regConst, KP: keep_prob,
                train_mode: True})
def train(self, logdir: Path, train_ids: List[str], valid_ids: List[str],
          validation: str, no_mp: bool=False, valid_only: bool=False,
          model_path: Path=None):
    """Train (or only validate) the model, saving a snapshot after each epoch.

    Args:
        logdir: TensorBoard log directory; also searched for the last
            snapshot when ``model_path`` is not given.
        train_ids: ids of training images (loaded up front via load_image).
        valid_ids: ids of validation images (loaded lazily, only when
            ``validation`` is not ``'square'``).
        validation: ``'square'`` validates on the top-left
            ``hps.validation_square`` crop of each training image instead of
            on ``valid_ids``.
        no_mp: forwarded to ``train_on_images``.
        valid_only: skip training; run one validation pass and stop without
            saving a snapshot.
        model_path: snapshot to restore. Its name must end in ``-<epoch>``;
            training resumes at the following epoch.

    Learning rate: with ``hps.lr_decay``, lr is ``hps.lr * lr_decay ** epoch``,
    recomputed on even epochs and on the first (resumed) epoch. Otherwise it
    is ``hps.lr`` until epoch 25, ``/5`` from 25 and ``/25`` from 50; on
    resume past a step the matching value is re-applied. Every lr change
    re-creates the optimizer via ``_init_optimizer``. NOTE(review): whether
    that resets optimizer state depends on ``_init_optimizer`` — confirm.

    NOTE(review): ``tb_logger``/``logdir`` are cleared only on normal exit
    (no try/finally).
    """
    self.tb_logger = tensorboard_logger.Logger(str(logdir))
    self.logdir = logdir
    train_images = [self.load_image(im_id) for im_id in sorted(train_ids)]
    valid_images = None  # built lazily after the first training pass
    if model_path:
        self.restore_snapshot(model_path)
        # Snapshot names end in '-<epoch>'; resume at the next epoch.
        start_epoch = int(model_path.name.rsplit('-', 1)[1]) + 1
    else:
        start_epoch = self.restore_last_snapshot(logdir)
    square_validation = validation == 'square'
    lr = self.hps.lr
    self.optimizer = self._init_optimizer(lr)
    for n_epoch in range(start_epoch, self.hps.n_epochs):
        if self.hps.lr_decay:
            # Exponential decay, refreshed every second epoch (and on resume).
            if n_epoch % 2 == 0 or n_epoch == start_epoch:
                lr = self.hps.lr * self.hps.lr_decay ** n_epoch
                self.optimizer = self._init_optimizer(lr)
        else:
            # Step schedule; the start_epoch clauses restore the right lr
            # when resuming past a step boundary.
            lim_1, lim_2 = 25, 50
            if n_epoch == lim_1 or (
                    n_epoch == start_epoch and n_epoch > lim_1):
                lr = self.hps.lr / 5
                self.optimizer = self._init_optimizer(lr)
            if n_epoch == lim_2 or (
                    n_epoch == start_epoch and n_epoch > lim_2):
                lr = self.hps.lr / 25
                self.optimizer = self._init_optimizer(lr)
        logger.info('Starting epoch {}, step {:,}, lr {:.8f}'.format(
            n_epoch + 1, self.net.global_step[0], lr))
        # Each epoch is split into `subsample` training passes, with a
        # validation run after each pass.
        subsample = 1 if valid_only else 2  # make validation more often
        for _ in range(subsample):
            if not valid_only:
                self.train_on_images(
                    train_images,
                    subsample=subsample,
                    square_validation=square_validation,
                    no_mp=no_mp)
            if valid_images is None:
                if square_validation:
                    # Validate on the top-left s x s crop of each training image.
                    s = self.hps.validation_square
                    valid_images = [
                        Image(None, im.data[:, :s, :s], im.mask[:, :s, :s])
                        for im in train_images]
                else:
                    valid_images = [self.load_image(im_id)
                                    for im_id in sorted(valid_ids)]
            # An empty list (no validation images) skips validation.
            if valid_images:
                self.validate_on_images(valid_images, subsample=1)
        if valid_only:
            break
        self.save_snapshot(n_epoch)
    self.tb_logger = None
    self.logdir = None
def main():
    """Run a frozen style-transfer graph on one image and write the result.

    Loads the GraphDef from ``--model_file``. The ``--input`` image is loaded
    with ``utils.load_image`` using the input placeholder's height/width. It
    is fed to ``import/<model_name>/image_in:0``, and
    ``import/<model_name>/output:0`` is written to ``--output``. Here
    ``<model_name>`` is the model file's base name without its extension.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_file', type=str, default='data/vg-30.pb',
                        help='Pretrained model file to run')
    parser.add_argument('--input', type=str,
                        default='data/sf.jpg',
                        help='Input image to process')
    parser.add_argument('--output', type=str, default="output.png",
                        help='Output image file')
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stdout,
                        format='%(asctime)s %(levelname)s:%(message)s',
                        level=logging.INFO,
                        datefmt='%I:%M:%S')

    with open(args.model_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def)
    graph = tf.get_default_graph()

    # Tensor names are scoped by the file's base name without extension
    # (e.g. 'data/vg-30.pb' -> 'vg-30'). The previous `[:-3]` slice only
    # worked for 3-character suffixes such as '.pb'.
    model_name = os.path.splitext(os.path.basename(args.model_file))[0]

    with tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=4)) as session:
        logging.info("Initializing graph")
        # NOTE(review): deprecated TF1 API, kept for compatibility with
        # whichever TF version this script targets.
        session.run(tf.initialize_all_variables())
        image = graph.get_tensor_by_name("import/%s/image_in:0" % model_name)
        out = graph.get_tensor_by_name("import/%s/output:0" % model_name)
        # Dims 1 and 2 of the input placeholder are used as height and width.
        shape = image.get_shape().as_list()
        target = [utils.load_image(args.input, image_h=shape[1], image_w=shape[2])]
        logging.info("Processing image")
        start_time = datetime.now()
        processed = session.run(out, feed_dict={image: target})
        logging.info("Processing took %f",
                     (datetime.now() - start_time).total_seconds())
        utils.write_image(args.output, processed)
    logging.info("Done")