def _create_image_encoder(preprocess_fn, factory_fn, image_shape, batch_size=32,
                          session=None, checkpoint_path=None,
                          loss_mode="cosine"):
    """Build a graph that embeds uint8 images and return a NumPy-level encoder.

    Graph layout: a uint8 placeholder is cast to float32, ``preprocess_fn`` is
    mapped over each image with ``tf.map_fn``, and the result is fed to the
    network built by ``factory_fn``.

    Args:
        preprocess_fn: Callable invoked on a single float32 image tensor as
            ``preprocess_fn(image, is_training=False)``.
        factory_fn: Network constructor invoked as
            ``factory_fn(batch, l2_normalize=..., reuse=None)``. The first
            element of the pair it returns is used as the feature tensor.
        image_shape: Tuple with the per-image shape. The input placeholder has
            shape ``(None, ) + image_shape``.
        batch_size: Chunk size handed to ``_run_in_batches`` when evaluating
            the features.
        session: Session to evaluate in. A new ``tf.Session()`` is created
            when this is None.
        checkpoint_path: When not None, the variables from
            ``slim.get_variables_to_restore()`` are assigned from this
            checkpoint. When None, no initializer is run here.
            NOTE(review): presumably the caller must have initialized the
            variables in ``session`` in that case; confirm against callers.
        loss_mode: Features are built with ``l2_normalize=True`` only when
            this equals ``"cosine"``.

    Returns:
        A callable ``encoder(data_x)`` that returns a float32 NumPy array of
        shape ``(len(data_x), feature_dim)``. ``feature_dim`` is the static
        size of the feature tensor's last axis.
    """
    # Graph ops are created in the same order as before, so op and
    # variable names stay stable (this matters for checkpoint restore).
    images_in = tf.placeholder(tf.uint8, (None, ) + image_shape)
    images_float = tf.cast(images_in, tf.float32)

    def _preprocess_one(single_image):
        return preprocess_fn(single_image, is_training=False)

    preprocessed = tf.map_fn(_preprocess_one, images_float)
    features, _ = factory_fn(
        preprocessed, l2_normalize=(loss_mode == "cosine"), reuse=None)
    embedding_size = features.get_shape().as_list()[-1]

    sess = tf.Session() if session is None else session

    if checkpoint_path is not None:
        # Make sure a global-step variable exists before collecting the
        # restore list. NOTE(review): presumably the checkpoint stores one.
        slim.get_or_create_global_step()
        assign_op, assign_feed = slim.assign_from_checkpoint(
            checkpoint_path, slim.get_variables_to_restore())
        sess.run(assign_op, feed_dict=assign_feed)

    def _embed_batch(feed_dict):
        return sess.run(features, feed_dict=feed_dict)

    def encoder(data_x):
        embeddings = np.zeros((len(data_x), embedding_size), dtype=np.float32)
        # _run_in_batches presumably slices the feed dict and writes each
        # chunk's result into `embeddings`; see its definition to confirm.
        _run_in_batches(
            _embed_batch, {images_in: data_x}, embeddings, batch_size)
        return embeddings

    return encoder