def forward_one_step(self, X, ht_enc):
    """Run one step of the layer on ``X``, conditioned on ``ht_enc``.

    The output of ``self.W(X)`` is sliced at index ``-self._kernel_size``
    along its last axis (kept as a length-1 axis), summed with the
    broadcast output of ``self.V(ht_enc)``, split along axis 1 into
    ``self.num_split`` pieces and handed to ``self.pool``.

    Args:
        X: Input passed to ``self.W``.
            NOTE(review): presumably (batch, channels, time) -- confirm.
        ht_enc: Input passed to ``self.V``.
            NOTE(review): name suggests an encoder hidden state of shape
            (batch, features) -- confirm against the caller.

    Returns:
        Whatever ``self.pool`` returns for the split pieces.
    """
    # Same index as ``-(kernel_size - 1) - 1``: counted from the end of
    # the last axis.
    step_index = -self._kernel_size

    # Select that single position; the trailing ``None`` re-adds a
    # length-1 axis so the rank of the projection is unchanged.
    wx_step = self.W(X)[..., step_index, None]

    # Insert an axis at position 2 so the second projection can be
    # broadcast against the sliced one.
    vh_step = functions.expand_dims(self.V(ht_enc), axis=2)
    vh_step, wx_step = functions.broadcast(vh_step, wx_step)

    pieces = functions.split_axis(wx_step + vh_step, self.num_split, axis=1)
    return self.pool(pieces)
评论列表
文章目录