rebar.py source code


Project: relax  Author: duvenaud
import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in 2.x

slim = tf.contrib.slim

# Q_COLLECTION must be defined at module level (not shown in this excerpt).


def _q_func(self, samples, collection='Q_FUNC'):
  '''Returns the learning signal and the function term.

  This is the implementation for SBNs for the ELBO.

  Args:
    samples: per-layer sampled latent variables, indexed by layer;
      samples[i]['activation'] holds the binary samples of layer i.
    collection: name of the extra variable collection that the Q-function
      weights are added to.

  Returns:
    learning_signal: the "reward" function
    function_term: part of the function that depends on the parameters
      and needs to have the gradient taken through
  '''
  # Create the Q-function variables on the first call; reuse them afterwards.
  reuse = True if self.run_q_func else None

  if self.hparams.task in ['sbn', 'omni']:
    n_output = self.hparams.n_input
  elif self.hparams.task == 'sp':
    n_output = self.hparams.n_input // 2
  else:
    raise ValueError('Unsupported task: %r' % self.hparams.task)

  with slim.arg_scope([slim.fully_connected],
                      weights_initializer=slim.variance_scaling_initializer(),
                      variables_collections=[collection,
                                             tf.GraphKeys.GLOBAL_VARIABLES,
                                             Q_COLLECTION]):
    i = self.hparams.n_layer - 1  # use only the last layer
    # Map binary samples {0, 1} to {-1, +1}.
    q_input = 2.0 * samples[i]['activation'] - 1.0
    h = self._create_transformation(q_input,
                                    n_output,
                                    reuse=reuse,
                                    scope_prefix='q_func_%d' % i)

  if self.hparams.task in ['sbn', 'omni']:
    # Reduce to a scalar; the 'sp' task keeps the per-unit output.
    h = tf.reduce_sum(h)

  self.run_q_func = True
  # The same tensor serves as both the learning signal and the function term.
  return h, h
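For readers without TensorFlow 1.x, the control flow of `_q_func` can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the project's code: `QFuncSketch` is a hypothetical class, its `_create_transformation` is a hypothetical single linear layer standing in for the real helper (whose definition is not shown here), it processes one sample vector rather than a batch, and a dict of weights mimics TF1 variable `reuse` semantics (`reuse=None` creates the variables and fails if they already exist; `reuse=True` requires them to exist).

```python
import random
from types import SimpleNamespace


class QFuncSketch:
  """Plain-Python mirror of the _q_func control flow (hypothetical; no TensorFlow)."""

  def __init__(self, hparams, seed=0):
    self.hparams = hparams
    self.run_q_func = False
    self._variables = {}  # scope name -> weight matrix; stands in for TF variables
    self._rng = random.Random(seed)

  def _create_transformation(self, x, n_output, reuse, scope_prefix):
    """Hypothetical linear layer with TF1-style reuse checks."""
    if reuse:
      if scope_prefix not in self._variables:
        raise ValueError('reuse=True but %s does not exist' % scope_prefix)
    else:
      if scope_prefix in self._variables:
        raise ValueError('%s already exists; pass reuse=True' % scope_prefix)
      self._variables[scope_prefix] = [
          [self._rng.gauss(0.0, 1.0) for _ in x] for _ in range(n_output)]
    weights = self._variables[scope_prefix]
    return [sum(w * xj for w, xj in zip(row, x)) for row in weights]

  def q_func(self, samples):
    # Create variables on the first call, reuse them afterwards.
    reuse = True if self.run_q_func else None
    task = self.hparams.task
    if task in ('sbn', 'omni'):
      n_output = self.hparams.n_input
    elif task == 'sp':
      n_output = self.hparams.n_input // 2
    else:
      raise ValueError('unsupported task: %r' % task)
    i = self.hparams.n_layer - 1  # last stochastic layer only
    # Map binary samples {0, 1} to {-1, +1}.
    x = [2.0 * b - 1.0 for b in samples[i]['activation']]
    h = self._create_transformation(x, n_output, reuse, 'q_func_%d' % i)
    if task in ('sbn', 'omni'):
      h = sum(h)  # scalar, playing the role of tf.reduce_sum
    self.run_q_func = True
    return h, h


# Usage: the first call creates 'q_func_1'; the second call reuses it.
q = QFuncSketch(SimpleNamespace(task='sbn', n_input=4, n_layer=2))
samples = [{'activation': [0, 1, 1]}, {'activation': [1, 0, 1]}]
first, _ = q.q_func(samples)
second, _ = q.q_func(samples)
assert first == second  # same weights and same input give the same signal
```

Returning `h` twice mirrors the original, where the same value is used as both the learning signal and the differentiable function term; the `run_q_func` flag is what switches later calls from creating variables to reusing them.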