snn_mlp_policy.py source code

python

Project: snn4hrl  Author: florensacc
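import numpy as np


def get_actions(self, observations):
        # Method excerpt from the policy class: returns (actions, agent_info) for a batch of observations.
        observations = np.array(observations)  # 2-D array needed for the bilinear outer product below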
        if self.latent_dim:
            if self.resample:
                latents = [self.latent_dist.sample(self.latent_dist_info) for _ in observations]
                print('resampling the latents')
            else:
                if np.size(self.latent_fix) != self.latent_dim:  # latent_fix does not hold a full latent yet, so reset
                    self.reset()
                if len(self.pre_fix_latent) == self.latent_dim:  # a pre-fixed latent is set; reset() puts the latent to it
                    self.reset()  # this overwrites the latent that was sampled or stored in latent_fix
                latents = np.tile(self.latent_fix, [len(observations), 1])  # one copy of the fixed latent per observation (broadcasting might be cleaner)
            if self.bilinear_integration:
                extended_obs = np.concatenate([observations, latents,
                                               np.reshape(
                                                   observations[:, :, np.newaxis] * latents[:, np.newaxis, :],
                                                   (observations.shape[0], -1))],
                                              axis=1)
            else:
                extended_obs = np.concatenate([observations, latents], axis=1)
        else:
            latents = np.array([[]] * len(observations))
            extended_obs = observations
        # mean and log_std also depend on the latents, which are fed in as extra observation features
        mean, log_std = self._f_dist(extended_obs)

        if self._set_std_to_0:
            actions = mean
            log_std = -1e6 * np.ones_like(log_std)
        else:
            rnd = np.random.normal(size=mean.shape)
            actions = rnd * np.exp(log_std) + mean
        return actions, dict(mean=mean, log_std=log_std, latents=latents)
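
The two numerical steps in `get_actions` can be tried out in isolation. The following is a minimal standalone sketch, not part of the snn4hrl code: the helper names `extend_observations` and `sample_gaussian_actions` are hypothetical, and the sample values are made up. It shows how the bilinear integration appends the flattened per-row outer product of observation and latent, and how an action is drawn as `mean + exp(log_std) * noise`.

```python
import numpy as np


def extend_observations(observations, latents, bilinear=True):
    """Hypothetical helper: concatenate observations and latents,
    optionally adding the flattened per-row outer product."""
    observations = np.asarray(observations, dtype=float)
    latents = np.asarray(latents, dtype=float)
    if not bilinear:
        return np.concatenate([observations, latents], axis=1)
    # (N, obs_dim, 1) * (N, 1, latent_dim) -> (N, obs_dim, latent_dim)
    outer = observations[:, :, np.newaxis] * latents[:, np.newaxis, :]
    flat_outer = outer.reshape(observations.shape[0], -1)
    return np.concatenate([observations, latents, flat_outer], axis=1)


def sample_gaussian_actions(mean, log_std, rng=None, deterministic=False):
    """Hypothetical helper: draw actions from N(mean, exp(log_std)^2),
    or return the mean when deterministic is True."""
    mean = np.asarray(mean, dtype=float)
    if deterministic:
        return mean
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.standard_normal(size=mean.shape)
    return noise * np.exp(log_std) + mean


# Made-up batch: 2 observations with obs_dim = 3, and a one-hot latent with latent_dim = 2.
obs = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
latent_fix = np.array([0.0, 1.0])
latents = np.tile(latent_fix, [len(obs), 1])  # same latent repeated for every row

extended = extend_observations(obs, latents)
print(extended.shape)  # (2, 11), i.e. 3 + 2 + 3 * 2 features per row
```

With a one-hot latent, the outer-product features are zero everywhere except in the slots belonging to the active latent, so the network input carries a copy of the observation gated by the active latent component.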