def __call__(self, inputs, state, scope=None):
    """Advance the cell by one time step.

    Recurrent path: ``rotationTransform`` is applied to the transposed
    state, then ``diagonalTransform``, then ``rotationTransform`` again with
    the identical parameter lists, and the result is transposed back.
    Input path: ``linearTransformWithBias`` projects ``inputs`` to
    ``self._num_units`` with ``bias=False``. The two paths are summed with a
    learned "Bias" variable and passed through ``tf.abs``.

    Args:
        inputs: input tensor for this step.
        state: previous hidden state. Presumably (batch, num_units), since
            it is transposed before the rotation and transposed back
            afterwards -- TODO confirm against rotationTransform.
        scope: optional variable-scope name; defaults to the class name.

    Returns:
        ``(output, output)``: the new hidden state is also the cell output.

    Side effects:
        Stores the second value returned by ``diagonalTransform`` on
        ``self.sigma``.
    """
    def rotate(tensor):
        # Both rotation passes are built from the same stored parameter lists.
        return rotationTransform(
            tensor, self._num_units, self._num_params,
            self._cos_list, self._sin_list, self._nsin_list,
            self._cos_idxs, self._sin_idxs, self._nsin_idxs)

    with vs.variable_scope(scope or type(self).__name__):
        # Recurrent path: rotate -> diagonal transform -> rotate (transposed in/out).
        rotated = rotate(tf.transpose(state))
        scaled, self.sigma = diagonalTransform(rotated, self._num_units)
        recurrent = tf.transpose(rotate(scaled))

        # Input path: projection with no bias of its own; the bias is added below.
        projected = linearTransformWithBias([inputs], self._num_units, bias=False)

        bias = vs.get_variable(
            "Bias", [self._num_units],
            dtype=tf.float32,
            initializer=init_ops.constant_initializer(dtype=tf.float32))

        new_state = tf.abs(recurrent + projected + bias)
    return new_state, new_state
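# 评论列表 ("comment list") -- web-page navigation residue, not code
# 文章目录 ("table of contents") -- web-page navigation residue, not code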