net.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目: speechless 作者: JuliusKunze 项目源码 文件源码
def __init__(self,
                 input_size_per_time_step: int,
                 allowed_characters: List[str],
                 use_raw_wave_input: bool = False,
                 activation: str = "relu",
                 output_activation: str = "softmax",
                 optimizer: Optional[Optimizer] = None,
                 dropout: Optional[float] = None,
                 load_model_from_directory: Optional[Path] = None,
                 load_epoch: Optional[int] = None,
                 allowed_characters_for_loaded_model: Optional[List[str]] = None,
                 frozen_layer_count: int = 0,
                 reinitialize_trainable_loaded_layers: bool = False,
                 use_asg: bool = False,
                 asg_transition_probabilities: Optional[ndarray] = None,
                 asg_initial_probabilities: Optional[ndarray] = None,
                 kenlm_directory: Optional[Path] = None):
        """Configure the network, build its predictive model and optionally load saved weights.

        Args:
            input_size_per_time_step: Size of the input at each time step; stored on the instance.
            allowed_characters: Characters the grapheme encoding is built from. When
                ``kenlm_directory`` is given, this must equal the lowercased characters read
                from ``kenlm_directory / "vocabulary"``.
            use_raw_wave_input: Stored on the instance (used elsewhere when building the net).
            activation: Name of the activation, stored on the instance.
            output_activation: Name of the output activation, stored on the instance.
            optimizer: Optimizer to store on the instance. If ``None`` (the default), a new
                ``Adam(1e-4)`` is created for this instance.
            dropout: Optional dropout value, stored on the instance.
            load_model_from_directory: If given, weights are loaded from this directory
                via ``self.load_weights`` after the net is created.
            load_epoch: Epoch forwarded to ``self.load_weights``; also stored on the instance.
            allowed_characters_for_loaded_model: Forwarded to ``self.load_weights``.
            frozen_layer_count: Stored on the instance. Must be 0 unless
                ``load_model_from_directory`` is given.
            reinitialize_trainable_loaded_layers: If true, ``frozen_layer_count`` is passed to
                ``self.load_weights`` as ``loaded_first_layers_count``; otherwise ``None`` is.
            use_asg: Selects ``AsgGraphemeEncoding`` (true) or ``CtcGraphemeEncoding`` (false).
            asg_transition_probabilities: ASG transition probabilities. Defaults to
                ``self._default_asg_transition_probabilities(grapheme_set_size)`` when ``None``.
            asg_initial_probabilities: ASG initial probabilities. Defaults to
                ``self._default_asg_initial_probabilities(grapheme_set_size)`` when ``None``.
            kenlm_directory: Optional directory whose ``vocabulary`` file is checked
                against ``allowed_characters``.

        Raises:
            ValueError: If ``frozen_layer_count > 0`` without ``load_model_from_directory``,
                or if ``allowed_characters`` differ from the kenlm vocabulary.
        """
        if frozen_layer_count > 0 and load_model_from_directory is None:
            raise ValueError("Layers cannot be frozen if model is trained from scratch.")

        self.kenlm_directory = kenlm_directory

        # Validate the kenlm vocabulary before building the network, so that a
        # configuration mismatch fails fast instead of after the model is constructed.
        if kenlm_directory is not None:
            # NOTE(review): `single` presumably requires the vocabulary file to contain
            # exactly one line — confirm against its definition.
            expected_characters = list(
                single(read_text(kenlm_directory / "vocabulary", encoding='utf8').splitlines()).lower())

            if allowed_characters != expected_characters:
                raise ValueError("Allowed characters {} differ from those expected by kenlm decoder: {}".
                                 format(allowed_characters, expected_characters))

        self.grapheme_encoding = AsgGraphemeEncoding(allowed_characters=allowed_characters) \
            if use_asg else CtcGraphemeEncoding(allowed_characters=allowed_characters)

        grapheme_set_size = self.grapheme_encoding.grapheme_set_size
        self.asg_transition_probabilities = \
            self._default_asg_transition_probabilities(grapheme_set_size) \
            if asg_transition_probabilities is None else asg_transition_probabilities
        self.asg_initial_probabilities = \
            self._default_asg_initial_probabilities(grapheme_set_size) \
            if asg_initial_probabilities is None else asg_initial_probabilities

        self.use_asg = use_asg
        self.frozen_layer_count = frozen_layer_count
        self.output_activation = output_activation
        self.activation = activation
        self.use_raw_wave_input = use_raw_wave_input
        self.input_size_per_time_step = input_size_per_time_step
        # Create the default optimizer per instance. A default such as `Adam(1e-4)` in the
        # signature is evaluated once at definition time and would be shared (with its state)
        # by every instance that relies on the default.
        self.optimizer = Adam(1e-4) if optimizer is None else optimizer
        self.load_epoch = load_epoch
        self.dropout = dropout
        self.predictive_net = self.create_predictive_net()
        self.prediction_phase_flag = 0.

        if load_model_from_directory is not None:
            self.load_weights(
                allowed_characters_for_loaded_model, load_epoch, load_model_from_directory,
                loaded_first_layers_count=frozen_layer_count if reinitialize_trainable_loaded_layers else None)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号