def __init__(self,
             input_size_per_time_step: int,
             allowed_characters: List[str],
             use_raw_wave_input: bool = False,
             activation: str = "relu",
             output_activation: str = "softmax",
             optimizer: Optional[Optimizer] = None,
             dropout: Optional[float] = None,
             load_model_from_directory: Optional[Path] = None,
             load_epoch: Optional[int] = None,
             allowed_characters_for_loaded_model: Optional[List[str]] = None,
             frozen_layer_count: int = 0,
             reinitialize_trainable_loaded_layers: bool = False,
             use_asg: bool = False,
             asg_transition_probabilities: Optional[ndarray] = None,
             asg_initial_probabilities: Optional[ndarray] = None,
             kenlm_directory: Optional[Path] = None):
    """Configure the model, build its predictive net and optionally load weights.

    Args:
        input_size_per_time_step: Stored as-is on the instance
            (presumably the feature dimension per time step -- confirm
            against ``create_predictive_net``).
        allowed_characters: Characters handed to the grapheme encoding. When
            ``kenlm_directory`` is given, this must equal the character list
            read from its ``vocabulary`` file.
        use_raw_wave_input: Stored as-is on the instance.
        activation: Stored as-is on the instance.
        output_activation: Stored as-is on the instance.
        optimizer: Optimizer to use. When ``None``, a fresh ``Adam(1e-4)`` is
            created for this instance.
        dropout: Stored as-is on the instance.
        load_model_from_directory: When not ``None``, ``load_weights`` is
            called with this directory after the net has been created.
        load_epoch: Stored on the instance and forwarded to ``load_weights``.
        allowed_characters_for_loaded_model: Forwarded to ``load_weights``.
        frozen_layer_count: Stored on the instance; requires
            ``load_model_from_directory`` when greater than zero.
        reinitialize_trainable_loaded_layers: When true, ``load_weights`` is
            told to load only the first ``frozen_layer_count`` layers
            (``loaded_first_layers_count``); otherwise ``None`` is passed.
        use_asg: Selects ``AsgGraphemeEncoding`` instead of
            ``CtcGraphemeEncoding``.
        asg_transition_probabilities: Used as given; when ``None`` the value
            of ``_default_asg_transition_probabilities`` for the grapheme set
            size is used.
        asg_initial_probabilities: Used as given; when ``None`` the value of
            ``_default_asg_initial_probabilities`` for the grapheme set size
            is used.
        kenlm_directory: Optional directory containing a ``vocabulary`` file
            that ``allowed_characters`` is validated against.

    Raises:
        ValueError: If layers are requested to be frozen without a model
            directory to load from, or if ``allowed_characters`` differs from
            the characters listed in the kenlm vocabulary file.
    """
    if frozen_layer_count > 0 and load_model_from_directory is None:
        raise ValueError("Layers cannot be frozen if model is trained from scratch.")

    self.kenlm_directory = kenlm_directory

    if use_asg:
        self.grapheme_encoding = AsgGraphemeEncoding(allowed_characters=allowed_characters)
    else:
        self.grapheme_encoding = CtcGraphemeEncoding(allowed_characters=allowed_characters)

    # Defaults are only computed when the caller did not supply probabilities.
    if asg_transition_probabilities is None:
        asg_transition_probabilities = self._default_asg_transition_probabilities(
            self.grapheme_encoding.grapheme_set_size)
    self.asg_transition_probabilities = asg_transition_probabilities

    if asg_initial_probabilities is None:
        asg_initial_probabilities = self._default_asg_initial_probabilities(
            self.grapheme_encoding.grapheme_set_size)
    self.asg_initial_probabilities = asg_initial_probabilities

    self.use_asg = use_asg
    self.frozen_layer_count = frozen_layer_count
    self.output_activation = output_activation
    self.activation = activation
    self.use_raw_wave_input = use_raw_wave_input
    self.input_size_per_time_step = input_size_per_time_step
    # A default written as `Adam(1e-4)` in the signature would be evaluated
    # once and shared by every instance; build a fresh optimizer per instance.
    self.optimizer = Adam(1e-4) if optimizer is None else optimizer
    self.load_epoch = load_epoch
    self.dropout = dropout

    # All configuration attributes above are assigned before the net is built.
    self.predictive_net = self.create_predictive_net()
    self.prediction_phase_flag = 0.

    if self.kenlm_directory is not None:
        # NOTE(review): `single` presumably requires the vocabulary file to
        # consist of exactly one line -- confirm against its definition.
        vocabulary_line = single(
            read_text(self.kenlm_directory / "vocabulary", encoding='utf8').splitlines())
        expected_characters = list(vocabulary_line.lower())

        if allowed_characters != expected_characters:
            raise ValueError("Allowed characters {} differ from those expected by kenlm decoder: {}".
                             format(allowed_characters, expected_characters))

    if load_model_from_directory is not None:
        self.load_weights(
            allowed_characters_for_loaded_model, load_epoch, load_model_from_directory,
            loaded_first_layers_count=frozen_layer_count if reinitialize_trainable_loaded_layers else None)
# (Web-page residue, not part of the code: "Comment list" / "Table of contents".)