def forward(self, x, lstm_hidden_vb=None):
p = x.view(x.size(0), self.input_dims[0] * self.input_dims[1])
p = self.rl1(self.fc1(p))
p = self.rl2(self.fc2(p))
p = self.rl3(self.fc3(p))
p = self.rl4(self.fc4(p))
p = p.view(-1, self.hidden_dim)
if self.enable_lstm:
p_, v_ = torch.split(lstm_hidden_vb[0],1)
c_p, c_v = torch.split(lstm_hidden_vb[1],1)
p, c_p = self.lstm(p, (p_, c_p))
p_out = self.policy_5(p)
sig = self.policy_sig(p)
sig = self.softplus(sig)
v = x.view(x.size(0), self.input_dims[0] * self.input_dims[1])
v = self.rl1_v(self.fc1_v(v))
v = self.rl2_v(self.fc2_v(v))
v = self.rl3_v(self.fc3_v(v))
v = self.rl4_v(self.fc4_v(v))
v = v.view(-1, self.hidden_dim)
if self.enable_lstm:
v, c_v = self.lstm_v(v, (v_, c_v))
v_out = self.value_5(v)
if self.enable_lstm:
return p_out, sig, v_out, (torch.cat((p,v),0), torch.cat((c_p, c_v),0))
else:
return p_out, sig, v_out
# NOTE(review): removed non-code artifacts left over from the web page this
# snippet was scraped from ("评论列表" / "文章目录" — blog comment-list and
# table-of-contents labels); they are not part of the program.