def _read(w_read, memory):
    """Return a weighted sum of memory rows for each sample.

    This is a batched ``dot(w_read, memory)``. For every sample ``b``, each
    row ``memory[b, i, :]`` is scaled by ``w_read[b, i]``. The scaled rows
    are then summed over axis 1.

    Args:
        w_read: read weights, shape (nb_sample, memory_dim).
        memory: memory contents, shape (nb_sample, memory_dim, memory_width).

    Returns:
        A tensor of shape (nb_sample, memory_width): one read vector per
        sample.
    """
    # NOTE(review): `T` is presumably `theano.tensor` imported at module
    # level; confirm against the file's imports.
    # Add a trailing broadcast axis so each scalar weight scales a full row.
    row_weights = w_read[:, :, None]
    scaled_rows = row_weights * memory
    # Collapse the memory_dim axis to get the per-sample read vector.
    return T.sum(scaled_rows, axis=1)