model.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:teras 作者: chantera 项目源码 文件源码
def __init__(self, *args, dropout=0.0):
        """Build one embedding link per positional argument.

        Each positional argument describes a single embedding table and may
        be one of:

        * a ``dict`` of keyword arguments for ``EmbedID``; it must contain
          either both ``in_size`` and ``out_size``, or an ``initialW`` array
          from whose shape the two sizes are taken;
        * a 2-D ``np.ndarray`` used as the initial weights, whose shape gives
          ``(in_size, out_size)``;
        * a ``(in_size, out_size)`` tuple, for a table with no initial
          weights supplied.

        The resulting ``EmbedID`` links are passed to the parent constructor
        and ``self.size`` is set to the sum of their output sizes.

        Args:
            *args: Embedding specifications as described above.
            dropout: Dropout ratio stored on the instance; must be ``0`` or
                a ``float``.

        Raises:
            ValueError: If a dict specifies neither both sizes nor
                ``initialW``, or if a non-dict argument is neither an
                ``np.ndarray`` nor a 2-tuple.
        """
        embeds = []
        self.size = 0
        for _args in args:
            if isinstance(_args, dict):
                # Work on a shallow copy so that filling in the inferred
                # sizes does not mutate the caller's dict.
                _args = dict(_args)
                vocab_size = _args.get('in_size', None)
                embed_size = _args.get('out_size', None)
                embeddings = _args.get('initialW', None)
                if vocab_size is None or embed_size is None:
                    if embeddings is None:
                        raise ValueError('embeddings or in_size/out_size '
                                         'must be specified')
                    # Sizes were not (fully) given: take both from the
                    # initial weight matrix.
                    vocab_size, embed_size = embeddings.shape
                    _args['in_size'] = vocab_size
                    _args['out_size'] = embed_size
            else:
                if isinstance(_args, np.ndarray):
                    vocab_size, embed_size = _args.shape
                    embeddings = _args
                # The length check must be on the argument itself; the
                # previous code checked `embeddings`, which is unbound (or
                # left over from an earlier iteration) at this point.
                elif isinstance(_args, tuple) and len(_args) == 2:
                    vocab_size, embed_size = _args
                    embeddings = None
                else:
                    raise ValueError('embeddings must be '
                                     'np.ndarray or tuple(len=2)')
                _args = {'in_size': vocab_size, 'out_size': embed_size,
                         'initialW': embeddings}
            embeds.append(EmbedID(**_args))
            self.size += embed_size
        super(Embed, self).__init__(*embeds)

        assert dropout == 0 or isinstance(dropout, float)
        self._dropout_ratio = dropout
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号