def __init__(
        self,
        env_spec,
        hidden_sizes=(32, 32),
        hidden_nonlinearity=NL.tanh,
        num_seq_inputs=1,
        prob_network=None,
):
    """
    Set up a categorical policy whose action probabilities come from a network
    with a softmax output.

    :param env_spec: A spec for the mdp; its action space must be Discrete.
    :param hidden_sizes: sizes of the fully connected hidden layers
    :param hidden_nonlinearity: nonlinearity applied in each hidden layer
    :param num_seq_inputs: factor multiplied with the flat observation dimension to
     size the network input (presumably the number of observations fed at once --
     TODO confirm against callers)
    :param prob_network: manually specified network for this policy; when supplied,
     hidden_sizes, hidden_nonlinearity and num_seq_inputs are not used
    :return:
    """
    # Keep this call first: it receives locals() while that mapping still holds
    # nothing but the constructor arguments, unmodified.
    Serializable.quick_init(self, locals())
    assert isinstance(env_spec.action_space, Discrete)

    # Use the caller's network when given; otherwise build the default MLP.
    network = prob_network
    if network is None:
        flat_input_dim = env_spec.observation_space.flat_dim * num_seq_inputs
        network = MLP(
            input_shape=(flat_input_dim,),
            output_dim=env_spec.action_space.n,
            hidden_sizes=hidden_sizes,
            hidden_nonlinearity=hidden_nonlinearity,
            output_nonlinearity=NL.softmax,
        )

    self._l_prob = network.output_layer
    self._l_obs = network.input_layer

    # Compile a function from the network's input variable to its output.
    compiled_inputs = [network.input_layer.input_var]
    compiled_output = L.get_output(network.output_layer)
    self._f_prob = ext.compile_function(compiled_inputs, compiled_output)

    self._dist = Categorical(env_spec.action_space.n)

    # Base-class initialisers run last, once the network attributes exist.
    super(CategoricalMLPPolicy, self).__init__(env_spec)
    LasagnePowered.__init__(self, [network.output_layer])
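# NOTE: web-page residue, not code -- "评论列表" (comment list)
# NOTE: web-page residue, not code -- "文章目录" (article table of contents)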