linalg.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:Parser-v1 作者: tdozat 项目源码 文件源码
def linear(inputs, output_size, add_bias=True, n_splits=1, initializer=None, scope=None, moving_params=None):
  """Apply one affine map to the last dimension of one or more tensors.

  Every input is flattened to 2-D (all leading dimensions merged), the
  flattened inputs are concatenated along the feature axis, multiplied by a
  'Weights' variable and optionally offset by a 'Biases' variable. The result
  is reshaped so that its leading dimensions are those of `inputs[0]` and its
  last dimension is `output_size * n_splits`.

  Args:
    inputs: a tensor, or a list/tuple of tensors. The sequence passed in is
      never modified. The last dimension of every tensor must be statically
      known; the leading dimensions are taken from `inputs[0]` at run time.
    output_size: size of the last dimension of each returned tensor.
    add_bias: if true, add a zero-initialized 'Biases' variable.
    n_splits: number of equally sized outputs to produce from one matmul.
    initializer: initializer for 'Weights'. When it is None and
      `moving_params` is None, one `orthonormal_initializer` matrix of shape
      (total_input_size, output_size) is tiled `n_splits` times along axis 1.
    scope: variable scope name; defaults to 'Linear'.
    moving_params: optional object whose `average(variable)` result is used
      in place of each variable (presumably a
      tf.train.ExponentialMovingAverage -- confirm against callers).

  Returns:
    A single tensor when `n_splits` is 1 (or less), otherwise the result of
    splitting that tensor into `n_splits` pieces along its last axis.
  """

  # Work on a private list: the caller's list must not be mutated, and a
  # tuple (accepted here) does not support the item assignment that the
  # flattening step used to perform in place.
  if isinstance(inputs, (list, tuple)):
    inputs = list(inputs)
  else:
    inputs = [inputs]
  output_size *= n_splits

  with tf.variable_scope(scope or 'Linear'):
    # Reformat the input
    shapes = [a.get_shape().as_list() for a in inputs]
    total_input_size = sum(shape[-1] for shape in shapes)
    rank = len(shapes[0])
    # Dynamic shape of the first input with the last dimension replaced by
    # the (static) output size; used to restore the leading dimensions.
    input_shape = tf.shape(inputs[0])
    output_shape = [input_shape[i] for i in range(rank)]
    output_shape[-1] = output_size
    output_shape = tf.pack(output_shape)
    flattened = [tf.reshape(input_, [-1, shape[-1]]) for input_, shape in zip(inputs, shapes)]
    concatenation = tf.concat(1, flattened)

    # Get the matrix
    if initializer is None and moving_params is None:
      # output_size was already multiplied by n_splits, so the floor
      # division recovers the per-split width.
      mat = orthonormal_initializer(total_input_size, output_size//n_splits)
      mat = np.concatenate([mat]*n_splits, axis=1)
      initializer = tf.constant_initializer(mat)
    matrix = tf.get_variable('Weights', [total_input_size, output_size], initializer=initializer)
    if moving_params is not None:
      matrix = moving_params.average(matrix)
    else:
      tf.add_to_collection('Weights', matrix)

    # Get the bias
    if add_bias:
      bias = tf.get_variable('Biases', [output_size], initializer=tf.zeros_initializer)
      if moving_params is not None:
        bias = moving_params.average(bias)
    else:
      bias = 0

    # Do the multiplication
    new = tf.matmul(concatenation, matrix) + bias
    new = tf.reshape(new, output_shape)
    # The reshape target is a dynamic tensor, so re-attach the one dimension
    # that is known statically (the last one).
    new.set_shape([tf.Dimension(None) for _ in range(rank-1)] + [tf.Dimension(output_size)])
    if n_splits > 1:
      return tf.split(len(new.get_shape().as_list())-1, n_splits, new)
    else:
      return new

#===============================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号