def __init__(self, state_size):
    """Create the critic's linear layers and initialise their parameters.

    The network is a stack of three fully connected layers:
    ``state_size -> 64 -> 64 -> 1``. After construction, every layer's
    weight matrix is re-drawn from a standard normal distribution and each
    row is rescaled to unit L2 norm; every bias vector is set to zero.

    Args:
        state_size: Number of input features of the first linear layer.
    """
    super(BaselineCritic, self).__init__()

    hidden_units = 64
    self.fc1 = nn.Linear(state_size, hidden_units, bias=True)
    self.fc2 = nn.Linear(hidden_units, hidden_units, bias=True)
    self.value = nn.Linear(hidden_units, 1, bias=True)

    # Parameter initialisation, applied to the layers in creation order.
    for layer in (self.fc1, self.fc2, self.value):
        weights = layer.weight.data
        # Sample N(0, 1) in place ...
        weights.normal_(0, 1)
        # ... then divide each row by its own L2 norm (sum runs over dim 1).
        row_norms = th.sqrt(weights.pow(2).sum(1, keepdim=True))
        weights *= 1.0 / row_norms
        # Multiplying by zero clears the bias in place.
        layer.bias.data.mul_(0.0)
评论列表
文章目录