def read(self, h):
    """Attend over the stored keys with query ``h`` and read from the values.

    The key and value buffers are stacked along axis 1, giving
    ``(B, M, m)`` tensors per the original shape annotations. Attention
    weights come from a softmax over the key/query products. The weights
    are stored on ``self.p`` and are also returned, together with the
    weighted read-out of the values.

    Args:
        h: query batch. Assumed to be ``(B, m)`` -- TODO confirm
            against callers.

    Returns:
        tuple: ``(o, p)``. ``o`` is the weighted read-out reshaped to
        ``(B, m)``. ``p`` is the ``(B, M)`` weight variable, the same
        object that is assigned to ``self.p``.
    """
    keys = F.stack(self.key_buff, axis=1)  # (B, M, m)
    scores = F.batch_matmul(keys, h, transa=False, transb=False)
    # Collapse the product to (batch, slots) before normalising.
    score_shape = (h.shape[0], keys.shape[1])
    self.p = F.softmax(F.reshape(scores, score_shape))  # (B, M)

    values = F.stack(self.val_buff, axis=1)  # (B, M, m)
    # batch_matmul reads a 2-D (B, M) operand as B column vectors.
    # transa=True turns each weight vector into a (1, M) row, so the
    # product is (B, 1, m).
    weighted = F.batch_matmul(self.p, values, transa=True, transb=False)
    read_out = F.reshape(weighted, (weighted.shape[0], weighted.shape[2]))  # (B, m)
    return read_out, self.p
评论列表
文章目录