nn.py source code

python

Project: icml17_knn  Author: taolei87
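import math
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.variable_scope, tf.get_variable)


def linearND(input_, output_size, scope, init_bias=0.0):
    """Apply a linear layer (x @ W + b) over the last dimension of an N-D tensor."""
    shape = input_.get_shape().as_list()
    ndim = len(shape)
    # Weight init stddev: 1/sqrt(fan_in), capped at 0.1
    stddev = min(1.0 / math.sqrt(shape[-1]), 0.1)
    with tf.variable_scope(scope):
        W = tf.get_variable("Matrix", [shape[-1], output_size], tf.float32,
                            tf.random_normal_initializer(stddev=stddev))
    # Dynamic leading dims of the input, with output_size appended as the new last dim
    X_shape = tf.gather(tf.shape(input_), list(range(ndim - 1)))
    target_shape = tf.concat([X_shape, [output_size]], 0)  # TF >= 1.0 order: (values, axis)
    # Collapse all leading dims into one batch dim so a 2-D matmul applies
    exp_input = tf.reshape(input_, [-1, shape[-1]])
    if init_bias is None:
        res = tf.matmul(exp_input, W)
    else:
        with tf.variable_scope(scope):
            b = tf.get_variable("bias", [output_size],
                                initializer=tf.constant_initializer(init_bias))
        res = tf.matmul(exp_input, W) + b
    # Restore the leading dims, then re-attach the static shape lost by tf.reshape
    res = tf.reshape(res, target_shape)
    res.set_shape(shape[:-1] + [output_size])
    return res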
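For readers without a TensorFlow 1.x environment, the computation linearND performs (flatten the leading dims, multiply by W, add the bias, restore the leading dims) can be sketched in plain Python. This is a hypothetical reference helper, not part of the icml17_knn project: it assumes the tensor is stored as a flat row-major list plus a shape tuple (to mimic tf.reshape), takes the weights as arguments instead of sampling them, and repeats the init-stddev rule in a separate hypothetical helper.

```python
import math


def linear_nd_reference(flat, shape, W, b=None):
    """Hypothetical reference for linearND: y = x @ W + b over the last axis.

    flat:  input values in row-major order (list of floats)
    shape: input shape, e.g. (batch, time, d_in)
    W:     d_in x d_out weight matrix as a list of rows
    b:     optional list of d_out biases (None mirrors init_bias=None)
    Returns (flat_out, out_shape).
    """
    d_in, d_out = shape[-1], len(W[0])
    assert len(W) == d_in and len(flat) == math.prod(shape)
    rows = len(flat) // d_in  # plays the role of -1 in reshape(input_, [-1, d_in])
    out = []
    for r in range(rows):
        x = flat[r * d_in:(r + 1) * d_in]  # one row of the 2-D view
        for j in range(d_out):
            acc = sum(x[k] * W[k][j] for k in range(d_in))
            out.append(acc + (b[j] if b is not None else 0.0))
    # Leading dims are unchanged; only the last dim becomes d_out
    return out, tuple(shape[:-1]) + (d_out,)


def init_stddev(fan_in):
    """Hypothetical helper: the same rule as linearND, 1/sqrt(fan_in) capped at 0.1."""
    return min(1.0 / math.sqrt(fan_in), 0.1)
```

For example, a (2, 2, 2) input holding 1..8 with W = [[1.0], [1.0]] and b = [0.5] sums each length-2 row and adds 0.5, giving [3.5, 7.5, 11.5, 15.5] with shape (2, 2, 1). The cap in init_stddev only takes effect when the fan-in is below 100, since 1/sqrt(100) = 0.1.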