def getpairs(model, batch, params):
    """Encode a batch of example pairs and build the matching pair batch.

    Each element of *batch* is indexed as ``element[0]`` / ``element[1]``,
    and each side exposes an ``.embeddings`` attribute.

    Steps:
    1. Split the sides of *batch* into two parallel lists of embeddings.
    2. Pass each list through ``utils.prepare_data``.
    3. Encode each side with ``model.feedforward_function``.
    4. Store row ``k`` of each encoding on the ``.representation`` attribute
       of the k-th element's corresponding side. This mutates *batch* in place.
    5. Call ``getPairsFast(batch, params.type)``, which returns a second
       collection of pairs (presumably negative/contrastive examples; verify
       against getPairsFast). Prepare that collection the same way.

    Args:
        model: Object providing ``feedforward_function(x, mask)``.
        batch: Re-iterable sequence of 2-indexable pair objects.
        params: Object providing a ``type`` attribute, forwarded to
            ``getPairsFast``.

    Returns:
        An 8-tuple ``(g1x, g1mask, g2x, g2mask, p1x, p1mask, p2x, p2mask)``
        holding the prepared inputs and masks for the batch sides and for the
        pairs returned by ``getPairsFast``.
    """

    def _split_sides(examples):
        # One pass over *examples*, so both returned lists stay index-aligned
        # with it. Side 0 is read before side 1 for every element.
        first, second = [], []
        for example in examples:
            first.append(example[0].embeddings)
            second.append(example[1].embeddings)
        return first, second

    # Prepare and encode both sides of every pair in the batch.
    gold_first, gold_second = _split_sides(batch)
    g1x, g1mask = utils.prepare_data(gold_first)
    g2x, g2mask = utils.prepare_data(gold_second)
    encoded_first = model.feedforward_function(g1x, g1mask)
    encoded_second = model.feedforward_function(g2x, g2mask)

    # Attach each encoded row to its element. The outputs are indexed as 2-D,
    # one row per element. These attributes are presumably what getPairsFast
    # reads when building pairs; verify.
    for row, example in enumerate(batch):
        example[0].representation = encoded_first[row, :]
        example[1].representation = encoded_second[row, :]

    # Build the secondary pairs and prepare them exactly like the batch.
    pairs = getPairsFast(batch, params.type)
    pair_first, pair_second = _split_sides(pairs)
    p1x, p1mask = utils.prepare_data(pair_first)
    p2x, p2mask = utils.prepare_data(pair_second)

    return (g1x, g1mask, g2x, g2mask, p1x, p1mask, p2x, p2mask)
ppdb_utils.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录