def step(self, x, states):
    """Run one decoding step: attend over the input sequence, update the state.

    Computes additive attention weights over ``self.x_seq`` from the
    previous hidden state, forms the context vector, updates a GRU-style
    hidden state conditioned on that context, and produces this step's
    softmax output.

    Args:
        x: Per-step input supplied by the RNN loop. It is not read here;
            the previous output ``ytm`` taken from ``states`` is used as
            the step input instead.
        states: Two-element sequence ``(ytm, stm)``: the previous output
            and the previous hidden state. ``stm`` is passed to
            ``K.repeat``, so it must be a 2D tensor (presumably
            ``(batch, units)`` -- confirm).

    Returns:
        A pair ``(output, [yt, st])``. Here ``yt`` is the new output and
        ``st`` the new hidden state. ``output`` is the attention weights
        ``at`` when ``self.return_probabilities`` is true, otherwise ``yt``.
    """
    ytm, stm = states

    # --- attention ---------------------------------------------------------
    # Repeat the previous hidden state along the time axis so it can be
    # combined with every input timestep at once.
    _stm = K.repeat(stm, self.timesteps)
    # Multiply the repeated hidden state by the attention weight matrix.
    _Wxstm = K.dot(_stm, self.W_a)
    # Attention energies: how much each input timestep contributes to the
    # current step. NOTE(review): self._uxpb is presumably a precomputed
    # input-side projection of x_seq -- confirm where it is set.
    et = K.dot(activations.tanh(_Wxstm + self._uxpb),
               K.expand_dims(self.V_a))
    # Softmax over the time axis (axis=1). Subtracting the per-sample
    # maximum before exponentiating does not change the normalised result,
    # but keeps K.exp from overflowing to inf. Such an overflow would turn
    # the weights into NaN after the division below.
    et = et - K.max(et, axis=1, keepdims=True)
    at = K.exp(et)
    # keepdims=True keeps the summed axis so the sum broadcasts against
    # `at` directly (no explicit K.repeat needed).
    at = at / K.sum(at, axis=1, keepdims=True)  # (batchsize, timesteps, 1)

    # Context vector: attention-weighted sum of the input sequence.
    # Assuming self.x_seq is (batch, timesteps, features), batch_dot over
    # axis 1 yields (batch, 1, features) and squeeze drops the singleton.
    context = K.squeeze(K.batch_dot(at, self.x_seq, axes=1), axis=1)

    # --- hidden-state update (GRU-style, conditioned on the context) -------
    # Reset gate: scales the previous state inside the proposal below.
    rt = activations.sigmoid(
        K.dot(ytm, self.W_r)
        + K.dot(stm, self.U_r)
        + K.dot(context, self.C_r)
        + self.b_r)
    # Update gate: interpolates between previous and proposal state.
    zt = activations.sigmoid(
        K.dot(ytm, self.W_z)
        + K.dot(stm, self.U_z)
        + K.dot(context, self.C_z)
        + self.b_z)
    # Proposal (candidate) hidden state.
    s_tp = activations.tanh(
        K.dot(ytm, self.W_p)
        + K.dot(rt * stm, self.U_p)
        + K.dot(context, self.C_p)
        + self.b_p)
    # New hidden state.
    st = (1 - zt) * stm + zt * s_tp

    # --- output ------------------------------------------------------------
    # NOTE(review): the output is computed from the previous state `stm`,
    # not the new `st`. This appears to follow the appendix formulation of
    # Bahdanau et al. (output conditioned on s_{i-1}); confirm it is
    # intended before changing.
    yt = activations.softmax(
        K.dot(ytm, self.W_o)
        + K.dot(stm, self.U_o)
        + K.dot(context, self.C_o)
        + self.b_o)

    if self.return_probabilities:
        return at, [yt, st]
    return yt, [yt, st]
评论列表
文章目录