nn.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:Sing_Par 作者: wanghm92 项目源码 文件源码
def linear_classifier(self, inputs, n_classes, add_bias=True):
    """Project the last axis of `inputs` to `n_classes` scores, with optional dropout.

    Every leading position of `inputs` is mapped independently from
    `input_size` features to `n_classes` outputs by a single `linalg.linear`
    layer; all leading dimensions are preserved in the result.

    Args:
      inputs: tensor of rank >= 2. Its last dimension must be statically
        known, since it is read from `get_shape()`. Leading dimensions may be
        dynamic and need not be equal to each other.
      n_classes: number of output units per position.
      add_bias: forwarded to `linalg.linear`.

    Returns:
      A tensor shaped `inputs.shape[:-1] + [n_classes]`. The static shape is
      set to unknown leading dims and a known last dim of `n_classes`.
    """

    input_shape = inputs.get_shape().as_list()
    n_dims = len(input_shape)
    input_size = input_shape[-1]
    output_size = n_classes
    dynamic_shape = tf.shape(inputs)
    batch_size = dynamic_shape[0]
    # Take every leading dimension from the runtime shape instead of repeating
    # dim 1 for all interior axes, so non-square inputs such as
    # [batch, n, m, dim] reshape correctly. Square inputs are unaffected.
    output_shape = tf.pack([dynamic_shape[i] for i in range(n_dims-1)] + [output_size])

    # Dropout applies only when moving_params is None. Presumably that is the
    # training path and moving_params holds averaged weights for evaluation —
    # confirm against the owning class.
    if self.moving_params is None:
      if self.drop_gradually:
        # keep_prob = s*1 + (1-s)*mlp_keep_prob: interpolate between "no
        # dropout" and mlp_keep_prob by global_sigmoid. Assumes s is in
        # [0, 1] — TODO confirm.
        s = self.global_sigmoid
        keep_prob = s + (1-s)*self.mlp_keep_prob
      else:
        keep_prob = self.mlp_keep_prob
    else:
      keep_prob = 1
    # A tensor keep_prob cannot be compared in Python, so it always takes the
    # dropout path. A Python number only does so when it is below 1.
    if isinstance(keep_prob, tf.Tensor) or keep_prob < 1:
      # Mask shape [batch, 1, ..., 1, input_size]: one feature mask per batch
      # element, broadcast across every interior position.
      noise_shape = tf.pack([batch_size] + [1]*(n_dims-2) + [input_size])
      inputs = tf.nn.dropout(inputs, keep_prob, noise_shape=noise_shape)

    # Flatten all leading axes so the linear layer sees a 2-D matrix.
    inputs = tf.reshape(inputs, [-1, input_size])
    output = linalg.linear(inputs,
                           output_size,
                           add_bias=add_bias,
                           initializer=tf.zeros_initializer,
                           moving_params=self.moving_params)
    output = tf.reshape(output, output_shape)
    # Reshaping to a dynamic shape loses static info; restore the rank and
    # the known class dimension for downstream shape inference.
    output.set_shape([tf.Dimension(None)]*(n_dims-1) + [tf.Dimension(output_size)])
    return output

  #=============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号