def __call__(self, X, ht_enc):
    """Combine a convolved input with a projected encoder state and pool it.

    Steps, as performed below:

    1. ``self.W`` is applied to ``X``.
    2. If ``self._kernel_size`` is larger than 1, the last
       ``self._kernel_size - 1`` positions along the final axis of that output
       are dropped.
    3. ``self.V`` is applied to ``ht_enc``.
    4. That projection gets a new axis at position 2 and is broadcast against
       the trimmed convolution output, so the same value is added at every
       position along the last axis.
    5. The sum is split into ``self.num_split`` equal chunks along axis 1.
    6. The resulting tuple is passed to ``self.pool``.

    Args:
        X: Input to ``self.W``. Presumably shaped (batch, channels, length);
            TODO confirm against the caller.
        ht_enc: Input to ``self.V``. Presumably an encoder hidden state shaped
            (batch, features); TODO confirm.

    Returns:
        Whatever ``self.pool`` returns for the tuple of split chunks.
    """
    trim = self._kernel_size - 1
    conv_out = self.W(X)
    if trim > 0:
        # NOTE(review): presumably ``self.W`` pads by kernel_size - 1 on both
        # sides, so cutting the tail keeps each output from seeing future
        # steps. Confirm against W's definition.
        # The guard matters: ``[..., :-0]`` would yield an empty slice.
        conv_out = conv_out[..., :-trim]

    enc_term = self.V(ht_enc)

    # Add a trailing length-1 axis and broadcast. For example, with a
    # (1, 3, 3) convolution output, [[11, 12, 13]] becomes
    # [[[11, 11, 11], [12, 12, 12], [13, 13, 13]]].
    # An explicit broadcast is used rather than relying on implicit
    # broadcasting in the addition below.
    widened = functions.expand_dims(enc_term, axis=2)
    enc_term, conv_out = functions.broadcast(widened, conv_out)

    chunks = functions.split_axis(conv_out + enc_term, self.num_split, axis=1)
    return self.pool(chunks)
评论列表
文章目录