def __call__(self, state):
    """Run the network on ``state`` and return a QuadraticActionValue.

    ``state`` is passed through every layer in ``self.hidden_layers``
    with a ReLU after each one. The final activation feeds the output
    heads ``self.v``, ``self.mu`` and ``self.mat_diag`` (plus
    ``self.mat_non_diag`` when the instance has that attribute).

    Args:
        state: Input handed to the first hidden layer.
            NOTE(review): presumably a batch of observations -- confirm
            against the callers.

    Returns:
        A ``QuadraticActionValue`` constructed from ``mu``, the matrix
        built from the ``mat_*`` heads, and ``v``, with ``min_action`` /
        ``max_action`` taken from ``self.action_space.low`` / ``.high``.
    """
    features = state
    for hidden_layer in self.hidden_layers:
        features = F.relu(hidden_layer(features))

    value = self.v(features)
    center = self.mu(features)

    if self.scale_mu:
        # Presumably squashes mu into the action bounds via tanh;
        # verify against scale_by_tanh's definition.
        center = scale_by_tanh(center, high=self.action_space.high,
                               low=self.action_space.low)

    # The diagonal head is exponentiated before use.
    diag = F.exp(self.mat_diag(features))
    if not hasattr(self, 'mat_non_diag'):
        # No off-diagonal head: square the diagonal output and add a
        # trailing axis. NOTE(review): looks like the one-dimensional
        # action case, yielding a (batch, 1, 1) matrix -- confirm.
        quad = F.expand_dims(diag ** 2, axis=2)
    else:
        off_diag = self.mat_non_diag(features)
        lower = lower_triangular_matrix(diag, off_diag)
        # Multiply the triangular factor with itself, second operand
        # flagged via transb=True (presumably "transpose b" -- confirm).
        quad = matmul_v3(lower, lower, transb=True)

    action_space = self.action_space
    return QuadraticActionValue(
        center, quad, value,
        min_action=action_space.low,
        max_action=action_space.high)
评论列表
文章目录