def forward(self, hidden_vb, memory_vb):
    # compute the parameters that control addressing for the heads
    # NOTE: to be consistent w/ the dnc paper, we use
    # NOTE: sigmoid to constrain to [0, 1]
    # NOTE: oneplus to constrain to [1, +inf]
    self.key_vb = torch.tanh(self.hid_2_key(hidden_vb)).view(-1, self.num_heads, self.mem_wid)  # TODO: relu to bias the memory to store positive values ??? check again
    self.beta_vb = F.softplus(self.hid_2_beta(hidden_vb)).view(-1, self.num_heads, 1)           # beta > 0 via softplus; the paper's oneplus would give beta >= 1, see https://github.com/deepmind/dnc/issues/9
    self.gate_vb = torch.sigmoid(self.hid_2_gate(hidden_vb)).view(-1, self.num_heads, 1)        # gate in (0, 1): interpolation gate, blends wl_{t-1} & wc
    self.shift_vb = F.softmax(self.hid_2_shift(hidden_vb).view(-1, self.num_heads, self.num_allowed_shifts), dim=2)  # shift: softmax over the allowed shifts, so each head's shift weights sum to 1
    self.gamma_vb = (1. + F.softplus(self.hid_2_gamma(hidden_vb))).view(-1, self.num_heads, 1)  # gamma >= 1 (oneplus): sharpens the final weights
    # now we compute the addressing mechanism
    self._content_focus(memory_vb)
    self._location_focus()
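
# --- A minimal sketch of the two addressing steps invoked above, following the
# --- standard NTM/DNC addressing equations. This is an illustrative assumption,
# --- not the repo's actual code: it assumes memory_vb has shape
# --- (batch, mem_hei, mem_wid), that the previous weights live in a
# --- hypothetical self.wl_prev_vb of shape (batch, num_heads, mem_hei), and
# --- that num_allowed_shifts is odd and centered at offset 0 (e.g. [-1, 0, 1]).
def _content_focus(self, memory_vb):
    # content addressing: cosine similarity between each head's key and every
    # memory row, scaled by the key strength beta, then normalized w/ softmax
    sim_vb = torch.einsum('bhw,bmw->bhm', self.key_vb, memory_vb)
    norm_vb = self.key_vb.norm(dim=2, keepdim=True) * memory_vb.norm(dim=2).unsqueeze(1) + 1e-8
    self.wc_vb = F.softmax(self.beta_vb * (sim_vb / norm_vb), dim=2)  # (batch, num_heads, mem_hei)

def _location_focus(self):
    # 1. interpolation: blend the content weights with the previous weights
    wg_vb = self.gate_vb * self.wc_vb + (1. - self.gate_vb) * self.wl_prev_vb
    # 2. circular convolution of the gated weights with the shift distribution
    half = self.num_allowed_shifts // 2
    ws_vb = torch.zeros_like(wg_vb)
    for idx, offset in enumerate(range(-half, half + 1)):
        ws_vb = ws_vb + self.shift_vb[:, :, idx:idx+1] * torch.roll(wg_vb, shifts=offset, dims=2)
    # 3. sharpening: raise to gamma >= 1, then renormalize so each head sums to 1
    ws_vb = ws_vb.pow(self.gamma_vb)
    self.wl_curr_vb = ws_vb / ws_vb.sum(dim=2, keepdim=True)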