def initialize_entities(self, entities, max_entnum, train=True):
    """Assign an id to every entity and fetch one embedded state per id.

    Args:
        entities: Sized sequence of hashable, ``int()``-convertible entity
            ids. Used as the keys of the first returned dict.
        max_entnum: Exclusive upper bound passed to
            ``self.xp.random.randint`` when ``train`` is true; unused
            otherwise.
        train: When true, every entity gets a randomly drawn id in
            ``[0, max_entnum)``. When false, every entity keeps its own id.
            The lookup Variable is created with ``volatile=not train``.

    Returns:
        A tuple ``(old2newD, e2sD)``:

        * ``old2newD`` maps each original entity to its assigned id
          (a plain ``int``).
        * ``e2sD`` maps each assigned id (a plain ``int``) to the slice of
          ``self.embed(...)`` output produced for that id.

        Both dicts are empty when ``entities`` is empty.
    """
    # ``len`` instead of truthiness: ``entities`` may be an array, whose
    # truth value is ambiguous for more than one element.
    if len(entities) == 0:
        # Splitting into zero sections is not a valid request, so there is
        # nothing to embed; answer with empty mappings instead of failing.
        return {}, {}

    if train:
        # NOTE(review): ids are drawn independently, so two entities can
        # receive the same id and then share one entry in ``e2sD`` --
        # confirm this is intended.
        drawn = self.xp.random.randint(0, max_entnum, len(entities))
    else:
        drawn = entities

    # Normalise once to plain ints so that the values of ``old2newD`` and
    # the keys of ``e2sD`` have the same type whatever scalar type the
    # array backend yields on iteration.
    new_ids = [int(new) for new in drawn]
    old2newD = dict(zip(entities, new_ids))

    id_var = chainer.Variable(
        self.xp.array(new_ids, dtype=np.int32), volatile=not train)
    states = F.split_axis(self.embed(id_var), len(new_ids), axis=0)
    # A single section may come back as a bare Variable rather than a
    # sequence (presumably version-dependent -- TODO confirm); wrap it so
    # the zip below always pairs one id with one state.
    if not isinstance(states, (tuple, list)):
        states = [states]

    e2sD = dict(zip(new_ids, states))
    return old2newD, e2sD
评论列表
文章目录