def pi_and_v(self, state):
    """Compute the ``pi`` and ``v`` outputs for ``state``.

    ``state`` is sent through two pipelines of identical shape, each made
    of three callables stored on ``self``: a head, a middle stage (named
    "lstm") and a tail.  One pipeline uses ``pi_head``, ``pi_lstm`` and
    ``pi``; the other uses ``v_head``, ``v_lstm`` and ``v``.

    NOTE(review): presumably these are the policy and value branches of
    an actor-critic model -- confirm against the enclosing class.

    Args:
        state: Input handed unchanged to both head callables.

    Returns:
        A ``(pout, vout)`` tuple holding the result of ``self.pi`` and
        the result of ``self.v``, in that order.
    """

    def forward(head, lstm, tail):
        # head -> ReLU -> middle stage -> tail; `state` comes from the
        # enclosing call, so both pipelines see the same input.
        h = F.relu(head(state))
        h = lstm(h)
        return tail(h)

    # The pi pipeline is evaluated before the v pipeline, as in the
    # original statement order.
    pout = forward(self.pi_head, self.pi_lstm, self.pi)
    vout = forward(self.v_head, self.v_lstm, self.v)
    return pout, vout
评论列表
文章目录