network.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:policy_search_bb-alpha 作者: siemens 项目源码 文件源码
def output(self, x,indexes,samples=0,use_indices=True):
        if samples == 0:
            samples = self.n_samples

        if use_indices == True:
            self.v_z = 1e-6 + self.logistic(self.log_var_param_z[ indexes, 0 : 1 ])*(self.v_prior_z - 2e-6)
            self.m_z = self.mean_param_z[ indexes, 0 : 1 ]
            self.z = self.randomness_z[ : , indexes, : ] * T.tile(T.sqrt(self.v_z), [ samples, 1, 1 ]) + T.tile(self.m_z, [ self.n_samples, 1, 1 ])
        else:
            self.z = self.randomness_z[:,0:x.shape[1],:] *  T.tile(T.sqrt(self.v_prior_z), [samples, 1, 1 ]) 

        x = T.concatenate((x,  self.z[ : , 0 : x.shape[ 1 ], : ]), 2)

        for layer in self.layers:
            x = layer.output(x,samples)

        return x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号