def __init__(
        self,
        name,
        env_spec,
        conv_filters, conv_filter_sizes, conv_strides, conv_pads,
        hidden_sizes=[],
        hidden_nonlinearity=NL.rectify,
        output_nonlinearity=NL.softmax,
        prob_network=None,
):
    """
    Set up a categorical policy backed by a convolutional network.

    When ``prob_network`` is not given, a ``ConvNetwork`` is built from the
    observation-space shape, the number of discrete actions and the
    network parameters below. A function mapping the network input to the
    network output is then compiled and kept on the instance.

    :param name: name of the policy (not used by this constructor)
    :param env_spec: A spec for the mdp. Its ``action_space`` must be
        ``Discrete``.
    :param conv_filters: forwarded unchanged to ``ConvNetwork``
    :param conv_filter_sizes: forwarded unchanged to ``ConvNetwork``
    :param conv_strides: forwarded unchanged to ``ConvNetwork``
    :param conv_pads: forwarded unchanged to ``ConvNetwork``
    :param hidden_sizes: list of sizes for the fully connected hidden
        layers. It is only forwarded here, never mutated.
    :param hidden_nonlinearity: nonlinearity used for each hidden layer
    :param output_nonlinearity: nonlinearity used for the output layer of
        the automatically built network. Defaults to ``NL.softmax``.
        NOTE(review): the output is handed to a ``Categorical``
        distribution, which presumably expects normalized probabilities --
        confirm before passing anything other than a softmax.
    :param prob_network: manually specified network for this policy, other
        network params are ignored
    """
    # Keep this as the first statement: at this point locals() contains
    # only the constructor arguments, which is what gets handed over.
    Serializable.quick_init(self, locals())

    assert isinstance(env_spec.action_space, Discrete)
    self._env_spec = env_spec

    num_actions = env_spec.action_space.n

    if prob_network is None:
        prob_network = ConvNetwork(
            input_shape=env_spec.observation_space.shape,
            output_dim=num_actions,
            conv_filters=conv_filters,
            conv_filter_sizes=conv_filter_sizes,
            conv_strides=conv_strides,
            conv_pads=conv_pads,
            hidden_sizes=hidden_sizes,
            hidden_nonlinearity=hidden_nonlinearity,
            # Honour the caller's choice; this used to be hard-coded to
            # NL.softmax, which silently discarded the argument.
            output_nonlinearity=output_nonlinearity,
            name="prob_network",
        )

    input_layer = prob_network.input_layer
    output_layer = prob_network.output_layer

    self._l_prob = output_layer
    self._l_obs = input_layer
    # Compiled function from the network's input variable to its output.
    self._f_prob = ext.compile_function(
        [input_layer.input_var],
        L.get_output(output_layer)
    )

    self._dist = Categorical(num_actions)

    super(CategoricalConvPolicy, self).__init__(env_spec)
    LasagnePowered.__init__(self, [output_layer])
categorical_conv_policy.py 文件源码
python
阅读 20
收藏 0
点赞 0
评论 0
评论列表
文章目录