def get_output_for(self, input_, **kwargs):
    """Add a pairwise-interaction term to the input, then apply the nonlinearity.

    Only the strictly lower-triangular part of ``self.W`` is used
    (``T.tril(self.W, -1)``). So each feature pair ``(i, j)`` with
    ``i > j`` contributes ``W[i, j] * x_i * x_j`` once, and the diagonal
    (self-interaction) terms are dropped.

    Parameters
    ----------
    input_ : Theano tensor
        Layer input. Assumed to be a matrix of shape ``(batch, num_units)``,
        with ``self.W`` of shape ``(num_units, num_units)``. TODO: confirm
        against the layer's output-shape logic.
    **kwargs
        Unused by this method.

    Returns
    -------
    Theano tensor
        ``self.nonlinearity(input_ + sqrt(maximum(q, 1e-6)))``, where ``q``
        is the per-row bilinear form ``x^T tril(W, -1) x``, broadcast across
        the feature axis.
    """
    W = T.tril(self.W, -1)
    # Row-wise x^T W x: one scalar per row of input_ (assuming 2-D input).
    interactions = T.batched_dot(T.dot(input_, W), input_)
    # Clamp elementwise so sqrt never sees a value <= 0; this keeps both
    # the output and its gradient finite. T.max is a reduction whose second
    # argument is an axis, so it cannot be used for this clamp.
    interactions = T.sqrt(T.maximum(interactions, 1e-6))
    # Reshape the per-row scalar into a (batch, 1) column so it is added
    # to every feature of its own row. A bare (batch,) vector would align
    # with the last (feature) axis instead.
    return self.nonlinearity(input_ + interactions.dimshuffle(0, 'x'))
评论列表
文章目录