models.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:keras-image-captioning 作者: danieljl 项目源码 文件源码
def __init__(self,
             learning_rate=None,
             vocab_size=None,
             embedding_size=None,
             rnn_output_size=None,
             dropout_rate=None,
             bidirectional_rnn=None,
             rnn_type=None,
             rnn_layers=None,
             l1_reg=None,
             l2_reg=None,
             initializer=None,
             word_vector_init=None):
    """Store the model hyperparameters, filling unset ones from the config.

    If an arg is None, it will get its value from config.active_config().
    Only ``None`` counts as "unset": explicitly passed falsy values such as
    ``dropout_rate=0.0``, ``l1_reg=0.0`` or ``bidirectional_rnn=False`` are
    kept exactly as given rather than being replaced by the config value.

    Args:
        learning_rate, vocab_size, embedding_size, rnn_output_size,
        dropout_rate, bidirectional_rnn, rnn_layers: stored as the
            corresponding ``self._<name>`` attribute.
        rnn_type: must resolve to ``'lstm'`` or ``'gru'``.
        l1_reg, l2_reg: combined into ``self._regularizer`` via ``l1_l2``.
        initializer: stored as ``self._initializer``; the string
            ``'vinyals_uniform'`` is replaced by ``RandomUniform(-0.08, 0.08)``.
        word_vector_init: if it resolves to a non-None value,
            ``embedding_size`` must be 300.

    Raises:
        ValueError: if the resolved ``vocab_size`` is None, ``rnn_type`` is
            not ``'lstm'``/``'gru'``, ``rnn_layers`` is < 1, or
            ``word_vector_init`` is set while ``embedding_size`` != 300.
    """
    def from_config(value, name):
        # Use ``is None`` rather than ``or`` so that legitimate falsy values
        # (0, 0.0, False) supplied by the caller are not discarded. The config
        # is only consulted for arguments the caller actually left unset.
        return getattr(active_config(), name) if value is None else value

    self._learning_rate = from_config(learning_rate, 'learning_rate')
    self._vocab_size = from_config(vocab_size, 'vocab_size')
    self._embedding_size = from_config(embedding_size, 'embedding_size')
    self._rnn_output_size = from_config(rnn_output_size, 'rnn_output_size')
    self._dropout_rate = from_config(dropout_rate, 'dropout_rate')
    self._rnn_type = from_config(rnn_type, 'rnn_type')
    self._rnn_layers = from_config(rnn_layers, 'rnn_layers')
    self._word_vector_init = from_config(word_vector_init, 'word_vector_init')

    self._initializer = from_config(initializer, 'initializer')
    # A named shorthand for the uniform(-0.08, 0.08) initializer.
    if self._initializer == 'vinyals_uniform':
        self._initializer = RandomUniform(-0.08, 0.08)

    self._bidirectional_rnn = from_config(bidirectional_rnn,
                                          'bidirectional_rnn')

    l1_reg = from_config(l1_reg, 'l1_reg')
    l2_reg = from_config(l2_reg, 'l2_reg')
    self._regularizer = l1_l2(l1_reg, l2_reg)

    # Not built here; presumably populated later by a build step (not
    # visible in this chunk) — TODO confirm.
    self._keras_model = None

    if self._vocab_size is None:
        raise ValueError('config.active_config().vocab_size cannot be '
                         'None! You should check your config or you can '
                         'explicitly pass the vocab_size argument.')

    if self._rnn_type not in ('lstm', 'gru'):
        raise ValueError('rnn_type must be either "lstm" or "gru"!')

    if self._rnn_layers < 1:
        raise ValueError('rnn_layers must be >= 1!')

    if self._word_vector_init is not None and self._embedding_size != 300:
        raise ValueError('If word_vector_init is not None, embedding_size '
                         'must be 300')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号