def Q_func(self, state):
    """Compute Q-values from a combined market/agent state array.

    The last axis of ``state`` is split into two feature groups: the
    first ``self.market_state_dim`` entries are fed through the
    ``s1``-``s3`` layers and the last ``self.agent_state_dim`` entries
    through the ``a1``-``a3`` layers, each layer followed by ``tanh``.
    The two branch outputs are concatenated along their last axis, passed
    through ``fc4`` and ``fc5`` (again with ``tanh``) and finally through
    ``q_value``, whose output is returned without an activation.

    Args:
        state: Array with ``ndim`` of 2 or 3 whose last axis holds the
            features.  NOTE(review): presumably shaped (batch, features)
            or (batch, steps, features) -- confirm against the callers.

    Returns:
        The output of ``self.q_value`` applied to the merged features.

    Raises:
        ValueError: If ``state.ndim`` is neither 2 nor 3.
    """
    # Fail early with a clear message; previously an unsupported rank fell
    # through both branches and crashed later on an unbound local name.
    if state.ndim not in (2, 3):
        raise ValueError(
            "Q_func expects state with ndim 2 or 3, got ndim={}".format(
                state.ndim
            )
        )

    # Ellipsis indexing selects along the last axis for both supported
    # ranks, which is exactly what the separate 2-D / 3-D slices did.
    agent_state = state[..., -self.agent_state_dim:]
    market_state = state[..., :self.market_state_dim]

    a_state = Variable(agent_state)
    m_state = Variable(market_state)

    # Agent branch.
    a = F.tanh(self.a1(a_state))
    a = F.tanh(self.a2(a))
    a = F.tanh(self.a3(a))

    # Market branch.
    m = F.tanh(self.s1(m_state))
    m = F.tanh(self.s2(m))
    m = F.tanh(self.s3(m))

    # Join on the last axis, the same axis the inputs were split on.  For
    # 2-D branch outputs this is identical to the former axis=1; for 3-D
    # outputs it no longer stacks along the middle axis by mistake.
    new_state = F.concat((a, m), axis=-1)

    h = F.tanh(self.fc4(new_state))
    h = F.tanh(self.fc5(h))
    Q = self.q_value(h)
    return Q
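# Comment list        (leftover web-page navigation text, not code)
# Table of contents   (leftover web-page navigation text, not code)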