def _load_model_from_config(config_path, hparam_overrides, vocab_file, mode):
    """Build a model instance from a YAML configuration file.

    The configuration is expected to contain a ``"model"`` entry (a dotted
    path resolvable by ``locate`` or the name of an attribute of ``models``)
    and a ``"model_params"`` mapping, which must itself contain a
    ``"decoder.params"`` mapping.

    Args:
        config_path: Path of the YAML configuration file, opened via
            ``gfile.GFile``.
        hparam_overrides: Optional mapping merged on top of the configured
            ``model_params``; ignored when falsy.
        vocab_file: Vocabulary path used for both ``vocab_source`` and
            ``vocab_target``.
        mode: Passed through unchanged to the model constructor.

    Returns:
        The object returned by ``model_cls(params=model_params, mode=mode)``.

    Raises:
        KeyError: If ``"model"``, ``"model_params"`` or ``"decoder.params"``
            is missing from the loaded configuration.
        AttributeError: If the model name can be resolved neither by
            ``locate`` nor as an attribute of ``models``.
    """
    with gfile.GFile(config_path) as config_file:
        # safe_load instead of a bare yaml.load: the latter can construct
        # arbitrary Python objects and requires an explicit Loader argument
        # on PyYAML >= 6 (calling it without one raises TypeError there).
        config = yaml.safe_load(config_file)

    # Try a fully-qualified dotted path first, then fall back to a class
    # exposed by the ``models`` package.
    model_cls = locate(config["model"]) or getattr(models, config["model"])

    model_params = config["model_params"]
    if hparam_overrides:
        model_params.update(hparam_overrides)

    # Change the max decode length to make the test run faster
    model_params["decoder.params"]["max_decode_length"] = 5
    model_params["vocab_source"] = vocab_file
    model_params["vocab_target"] = vocab_file
    return model_cls(params=model_params, mode=mode)
评论列表
文章目录