def main(unused_argv):
  """Runs inference with the reader, model and transformer selected by FLAGS.

  Validates the required flags, builds a frame-level or aggregated feature
  reader, looks up the model and feature-transformer classes by name, builds
  the inference graph, creates a checkpoint Saver, and then calls `inference`
  with FLAGS.train_dir, FLAGS.output_file, FLAGS.batch_size and FLAGS.top_k.

  Args:
    unused_argv: Positional command-line arguments; ignored. (Presumably
      supplied by the app/flags runner — confirm against the caller.)

  Raises:
    ValueError: If FLAGS.output_file or FLAGS.input_data_pattern is empty or
      unset.
  """
  logging.set_verbosity(tf.logging.INFO)

  # Validate required flags before building anything so a misconfigured run
  # fails immediately. A truthiness test is used rather than `is ""`: `is`
  # compares object identity, not string contents, and would not catch an
  # unset (None) flag.
  if not FLAGS.output_file:
    raise ValueError("'output_file' was not specified. "
                     "Unable to continue with inference.")
  if not FLAGS.input_data_pattern:
    raise ValueError("'input_data_pattern' was not specified. "
                     "Unable to continue with inference.")

  # Convert the feature_names and feature_sizes flags into lists of values.
  feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
      FLAGS.feature_names, FLAGS.feature_sizes)

  # Choose the reader according to whether frame-level features were requested.
  if FLAGS.frame_features:
    reader = readers.YT8MFrameFeatureReader(feature_names=feature_names,
                                            feature_sizes=feature_sizes)
  else:
    reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                 feature_sizes=feature_sizes)

  # The model lookup result is called to obtain a model instance; the
  # transformer lookup result is handed to build_graph as a class, uninstantiated.
  model = find_class_by_name(FLAGS.model,
                             [frame_level_models, video_level_models])()
  transformer_class = find_class_by_name(FLAGS.feature_transformer,
                                         [feature_transform])

  build_graph(reader,
              model,
              input_data_pattern=FLAGS.input_data_pattern,
              batch_size=FLAGS.batch_size,
              transformer_class=transformer_class)

  # The Saver is created after build_graph so the graph's variables already
  # exist when it is constructed. The very large keep_checkpoint_every_n_hours
  # value effectively disables time-based retention of extra checkpoints.
  saver = tf.train.Saver(max_to_keep=3,
                         keep_checkpoint_every_n_hours=10000000000)

  inference(saver, FLAGS.train_dir,
            FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k)
评论列表
文章目录