dern.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:der-network 作者: soskek 项目源码 文件源码
def initialize_entities(self, entities, max_entnum, train=True):
        e2sD = {}
        old2newD = {}

        if train:
            news = self.xp.random.randint(0, max_entnum, len(entities))
        else:
            news = entities

        new_e_L = []
        for new, entity in zip(news, entities):
            old2newD[entity] = int(new)
            new_e_L.append(new)

        es_L = F.split_axis(
            self.embed(chainer.Variable(self.xp.array(new_e_L, dtype=np.int32), volatile=not train)),
            len(new_e_L), axis=0)
        if len(new_e_L) <= 1:
            es_L = [es_L]
        for new_e, es in zip(new_e_L, es_L):
            e2sD[new_e] = es

        return old2newD, e2sD
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号