nn.py source code

Language: Python

Project: Parser-v1, Author: tdozat
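```python
import tensorflow as tf  # TF 1.x graph-mode API (tf.Dimension, tf.nn.dropout with keep_prob)
# `linalg` is Parser-v1's own module, not part of TensorFlow; its import is not shown here.

# Method of a class in nn.py, shown here without the enclosing class.
def linear(self, inputs, output_size, n_splits=1, add_bias=False):
  """Applies a linear layer over the last axis of `inputs`, with optional dropout.

  Returns one tensor if n_splits == 1, otherwise a list of n_splits tensors,
  each with last dimension output_size.
  """

  n_dims = len(inputs.get_shape().as_list())
  batch_size = tf.shape(inputs)[0]
  bucket_size = tf.shape(inputs)[1]
  input_size = inputs.get_shape().as_list()[-1]
  # tf.pack was renamed tf.stack in TensorFlow 1.0. (output_shape is not used below.)
  output_shape = tf.stack([batch_size] + [bucket_size]*(n_dims-2) + [output_size])
  # Static shape for each result: every axis unknown except the last.
  shape_to_set = [tf.Dimension(None)]*(n_dims-1) + [tf.Dimension(output_size)]

  # Dropout is applied only when self.moving_params is None; otherwise keep_prob is 1.
  if self.moving_params is None:
    keep_prob = self.info_keep_prob
  else:
    keep_prob = 1

  if keep_prob < 1:
    # noise_shape is 1 on the middle axes, so each sequence's dropout mask
    # is broadcast (shared) across all time steps.
    noise_shape = tf.stack([batch_size] + [1]*(n_dims-2) + [input_size])
    inputs = tf.nn.dropout(inputs, keep_prob, noise_shape=noise_shape)

  lin = linalg.linear(inputs,
                      output_size,
                      n_splits=n_splits,
                      add_bias=add_bias,
                      moving_params=self.moving_params)
  # Wrap a single result in a list so both cases share the loop below.
  if n_splits == 1:
    lin = [lin]
  # Re-assert the known last dimension (output_size) on every split.
  for split in lin:
    split.set_shape(shape_to_set)
  if n_splits == 1:
    return lin[0]
  else:
    return lin
```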

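The `noise_shape` passed to `tf.nn.dropout` has size 1 on every axis between the batch and feature axes, so a single keep/drop decision per (sequence, feature) pair is broadcast across all time steps. Below is a minimal pure-Python sketch of that behaviour for a rank-3 input, assuming the inverted-dropout scaling by `1 / keep_prob` that `tf.nn.dropout` uses. The helper name `shared_mask_dropout` and the nested-list representation are hypothetical, chosen only for illustration.

```python
import random
from typing import List

Tensor3 = List[List[List[float]]]  # [batch][time][feature]


def shared_mask_dropout(inputs: Tensor3, keep_prob: float, seed: int = 0) -> Tensor3:
    """Dropout whose mask has shape [batch, 1, features] (hypothetical helper).

    Mirrors noise_shape = [batch_size, 1, input_size]: each (sequence, feature)
    pair is kept or dropped once, and that decision is reused at every time step.
    Kept values are scaled by 1 / keep_prob so the expected value is unchanged.
    """
    if not 0.0 < keep_prob <= 1.0:
        raise ValueError("keep_prob must be in (0, 1]")
    rng = random.Random(seed)
    outputs = []
    for sequence in inputs:
        n_features = len(sequence[0]) if sequence else 0
        # One mask per sequence, broadcast over the time axis.
        mask = [1.0 / keep_prob if rng.random() < keep_prob else 0.0
                for _ in range(n_features)]
        outputs.append([[x * m for x, m in zip(step, mask)] for step in sequence])
    return outputs


if __name__ == "__main__":
    batch = [[[1.0, 2.0, 3.0, 4.0]] * 3]  # 1 sequence, 3 time steps, 4 features
    for step in shared_mask_dropout(batch, keep_prob=0.5, seed=1):
        print(step)  # the same features are zeroed at every step
```

Sharing the mask over time means a dropped feature stays dropped for the whole sequence, instead of flickering on and off from one position to the next.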
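
`linalg.linear` is Parser-v1's own helper, and its body is not shown on this page. One common way to support an `n_splits` option, assumed here rather than taken from the project, is to compute a single affine map with `n_splits * output_size` outputs and cut the result into `n_splits` equal pieces along the last axis (e.g. several gates computed in one matrix multiply). The sketch below does this in pure Python for a 2-D batch; `split_linear` and its arguments are hypothetical.

```python
from typing import List, Optional

Matrix = List[List[float]]


def split_linear(inputs: Matrix, weights: Matrix, n_splits: int = 1,
                 bias: Optional[List[float]] = None) -> List[Matrix]:
    """Compute inputs @ weights (+ bias), then split the columns into n_splits parts.

    inputs:  [batch][input_size]
    weights: [input_size][n_splits * output_size]
    Returns a list of n_splits matrices, each [batch][output_size].
    """
    total = len(weights[0])
    if total % n_splits != 0:
        raise ValueError("weight columns must be divisible by n_splits")
    output_size = total // n_splits
    product = []
    for row in inputs:
        # Column j of the output is the dot product of `row` with column j of `weights`.
        out = [sum(x * w_row[j] for x, w_row in zip(row, weights)) for j in range(total)]
        if bias is not None:
            out = [o + b for o, b in zip(out, bias)]
        product.append(out)
    # Slice the concatenated output columns into n_splits equal chunks.
    return [[r[k * output_size:(k + 1) * output_size] for r in product]
            for k in range(n_splits)]
```

For example, `split_linear([[1, 2]], [[1, 0, 2, 0], [0, 1, 0, 3]], n_splits=2)` computes the row `[1, 2, 2, 6]` and returns `[[[1, 2]], [[2, 6]]]`. The sketch always returns a list; the `linear` method above likewise normalizes to a list and unwraps it when `n_splits == 1`.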