def forward(self, x, fast=False, unitTest=False):
    """Run the gate, then one expert implementation, on ``x``.

    Args:
        x: Input passed unchanged to ``self.gate`` and to the chosen expert
            implementation.
        fast: If true, use ``self.fastExperts``; otherwise use
            ``self.vanillaExperts``. Ignored when ``unitTest`` is true.
        unitTest: If true, run BOTH expert implementations on the same gate
            output and return ``t.abs(vanilla_out - fast_out)`` instead of
            the usual tuple.

    Returns:
        When ``unitTest`` is true: the element-wise absolute difference
        between the vanilla and fast expert outputs.
        Otherwise a tuple ``(ret, forwardTime, cellTime)``:
        ``ret`` and ``cellTime`` are the two values returned by the chosen
        expert implementation, and ``forwardTime`` is the elapsed time in
        seconds for this whole call (gate + experts).
    """
    # perf_counter is monotonic and high-resolution. time.time() is a
    # wall-clock reading that can jump (NTP/DST), which corrupts intervals.
    # NOTE(review): if x lives on an asynchronous device (e.g. CUDA), work may
    # still be in flight when the clock is read -- confirm whether a device
    # synchronize is needed for meaningful timings.
    start = time.perf_counter()

    gates, expertInds = self.gate(x)

    if unitTest:
        # Compare both implementations on identical gating. Distinct local
        # names avoid shadowing the ``fast`` flag parameter.
        vanilla_out, _ = self.vanillaExperts(x, gates, expertInds)
        fast_out, _ = self.fastExperts(x, gates, expertInds)
        return t.abs(vanilla_out - fast_out)

    experts = self.fastExperts if fast else self.vanillaExperts
    ret, cellTime = experts(x, gates, expertInds)

    forwardTime = time.perf_counter() - start
    return ret, forwardTime, cellTime
# 评论列表 (scraped page residue: "comment list")
# 文章目录 (scraped page residue: "table of contents")