token_container.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:basis 作者: vaitech 项目源码 文件源码
def to_embedding(self, vector_dim=None, learn_difference=False, name=None,
                     embeddings_initializer='he_normal'):
        """Build a Keras embedding layer sized for this container.

        The layer's vocabulary covers ``self.size`` tokens plus one row per
        entry of ``self._special_tokens``. When ``self.W`` is set it is used
        as the initial weights, with the special-token rows (placed first)
        initialised to zeros.

        Args:
            vector_dim: Output dimension of the embedding. Required when
                ``self.W`` is None; ignored otherwise, because the dimension
                is then taken from ``self.W.shape[1]``.
            learn_difference: If true, freeze the base embedding and return a
                model that adds a trainable, small-valued "delta" embedding
                on top of it.
            name: Optional layer name. With ``learn_difference`` it is also
                used as the prefix for the names of the auxiliary layers.
            embeddings_initializer: Initializer passed to the base
                ``Embedding`` layer.

        Returns:
            A ``keras.layers.Embedding`` layer, or, when ``learn_difference``
            is true, a ``keras.models.Model`` mapping int32 token indices to
            the sum of the frozen base embedding and the delta embedding.

        Raises:
            ValueError: If the container has no matrix ``W`` and
                ``vector_dim`` was not given.
        """
        n_special = len(self._special_tokens)
        input_dim = self.size + n_special

        # Validate and prepare the weights before importing Keras, so that
        # bad arguments fail fast with a clear error.
        weights = None
        if self.W is not None:
            # presumably self.W has shape (self.size, dim) -- verify against
            # the code that fills it in.
            vector_dim = self.W.shape[1]
            padded = np.zeros((input_dim, vector_dim))
            # Special tokens occupy the first rows and stay zero.
            padded[n_special:, :] = self.W
            # Keras expects ``weights`` as a list of arrays.
            weights = [padded]
        elif vector_dim is None:
            raise ValueError('If container has no matrix W defined, vector '
                             'dimension for embedding must be explicitly '
                             'specified.')

        from keras.layers import Embedding

        emb = Embedding(
            input_dim=input_dim,
            output_dim=vector_dim,
            weights=weights,
            mask_zero=True,
            name=name,
            embeddings_initializer=embeddings_initializer
        )

        if not learn_difference:
            return emb

        if weights is None:
            logger.warning('Learning a difference on top of non-pretrained '
                           'word vectors is not recommended')

        from keras.models import Model
        from keras.initializers import RandomUniform
        from keras.layers import Input, add

        # Only the delta embedding is trained; the base vectors stay fixed.
        emb.trainable = False
        delta_initializer = RandomUniform(minval=-0.005, maxval=0.005)

        if name is None:
            # Fall back to the auto-generated layer name as the prefix.
            name = emb.name

        delta = Embedding(
            input_dim=input_dim,
            output_dim=vector_dim,
            embeddings_initializer=delta_initializer,
            mask_zero=True,
            name=name + '/delta_correction'
        )

        x = Input((None, ), dtype='int32', name=name + '/input')
        e = add([emb(x), delta(x)], name=name + '/addition')

        return Model(x, e, name='shifted_emb/' + name)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号