layers.py source code

python

Project: photo-editing-tensorflow  Author: JamesChuanggg
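```python
# Requires TensorFlow 1.x: tf.contrib was removed in TensorFlow 2.x, and
# tf.variable_scope / tf.get_variable only survive there as tf.compat.v1.
from functools import reduce  # a builtin in Python 2, must be imported in Python 3

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import initializers


def linear(input_,
           output_size,
           weights_initializer=initializers.xavier_initializer(),
           biases_initializer=tf.zeros_initializer(),
           activation_fn=None,
           trainable=True,
           name='linear'):
  """Fully connected layer computing activation_fn(input_ @ w + b).

  Inputs with rank > 2 are flattened to [batch, features] first.
  Returns the tuple (output, w, b).
  """
  shape = input_.get_shape().as_list()

  # Flatten everything after the batch axis, e.g. [N, H, W, C] -> [N, H*W*C].
  # This assumes the non-batch dimensions are known statically (not None).
  if len(shape) > 2:
    input_ = tf.reshape(input_, [-1, reduce(lambda x, y: x * y, shape[1:])])
    shape = input_.get_shape().as_list()

  with tf.variable_scope(name):
    # Weight matrix [in_features, output_size] and bias vector [output_size].
    w = tf.get_variable('w', [shape[1], output_size], tf.float32,
        initializer=weights_initializer, trainable=trainable)
    b = tf.get_variable('b', [output_size],
        initializer=biases_initializer, trainable=trainable)
    out = tf.nn.bias_add(tf.matmul(input_, w), b)

    if activation_fn is not None:
      return activation_fn(out), w, b
    else:
      return out, w, b
```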
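For readers without a TensorFlow 1.x environment, the computation that `linear` performs can be sketched in plain NumPy. This is an illustrative sketch, not code from the photo-editing-tensorflow project: `linear_numpy` and `xavier_uniform` are hypothetical names, and the weight initializer assumes the Glorot/Xavier uniform scheme (samples from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))), which is what `xavier_initializer()` uses with its default `uniform=True`. Biases start at zero, matching `zeros_initializer`.

```python
import numpy as np


def xavier_uniform(fan_in, fan_out, rng):
  """Glorot/Xavier uniform init: U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out))."""
  limit = np.sqrt(6.0 / (fan_in + fan_out))
  return rng.uniform(-limit, limit, size=(fan_in, fan_out)).astype(np.float32)


def linear_numpy(x, output_size, activation_fn=None, seed=0):
  """Hypothetical NumPy analogue of `linear`: flatten, x @ w + b, optional activation.

  Returns (out, w, b), mirroring the TensorFlow version's return value.
  """
  x = np.asarray(x, dtype=np.float32)
  # Flatten every axis after the batch axis, like the tf.reshape branch.
  if x.ndim > 2:
    x = x.reshape(-1, int(np.prod(x.shape[1:])))
  rng = np.random.default_rng(seed)
  w = xavier_uniform(x.shape[1], output_size, rng)
  b = np.zeros(output_size, dtype=np.float32)  # same role as zeros_initializer
  out = x @ w + b
  if activation_fn is not None:
    out = activation_fn(out)
  return out, w, b


# Usage: a batch of 4 images of shape 8x8x3 is flattened to 192 features.
images = np.ones((4, 8, 8, 3))
out, w, b = linear_numpy(images, 10, activation_fn=lambda z: np.maximum(z, 0))
print(out.shape, w.shape, b.shape)  # (4, 10) (192, 10) (10,)
```

Returning `w` and `b` alongside the output, as the TensorFlow version does, lets the caller inspect or reuse the parameters (for example, to copy them into a target network) without looking them up by scope name.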